turul_mcp_aws_lambda/
cors.rs

//! CORS (Cross-Origin Resource Sharing) support for Lambda MCP servers.
//!
//! This module injects CORS headers directly into Lambda responses, because Tower
//! middleware cannot be used in the Lambda execution environment.
5
6use std::collections::HashSet;
7
8use http::{HeaderValue, Method};
9use lambda_http::{Body as LambdaBody, Response as LambdaResponse};
10use tracing::debug;
11
12use crate::error::{LambdaError, Result};
13
14/// CORS configuration for Lambda MCP servers
15#[derive(Debug, Clone)]
16pub struct CorsConfig {
17    /// Allowed origins for CORS requests
18    /// Use "*" to allow all origins (not recommended for production)
19    pub allowed_origins: Vec<String>,
20
21    /// Allowed HTTP methods
22    pub allowed_methods: Vec<Method>,
23
24    /// Allowed request headers
25    pub allowed_headers: Vec<String>,
26
27    /// Whether to allow credentials (cookies, authorization headers)
28    pub allow_credentials: bool,
29
30    /// Maximum age for preflight cache (in seconds)
31    pub max_age: Option<u32>,
32
33    /// Headers to expose to the client
34    pub expose_headers: Vec<String>,
35}
36
37impl Default for CorsConfig {
38    fn default() -> Self {
39        Self {
40            allowed_origins: vec!["*".to_string()],
41            allowed_methods: vec![Method::GET, Method::POST, Method::DELETE, Method::OPTIONS],
42            allowed_headers: vec![
43                "Content-Type".to_string(),
44                "Accept".to_string(),
45                "Authorization".to_string(),
46                "Mcp-Session-Id".to_string(),
47                "Mcp-Protocol-Version".to_string(),
48                "Last-Event-ID".to_string(),
49            ],
50            allow_credentials: false,
51            max_age: Some(86400), // 24 hours
52            expose_headers: vec![
53                "Mcp-Session-Id".to_string(),
54                "Mcp-Protocol-Version".to_string(),
55            ],
56        }
57    }
58}
59
60impl CorsConfig {
61    /// Create a CORS config that allows all origins (for development)
62    pub fn allow_all() -> Self {
63        Self::default()
64    }
65
66    /// Create a CORS config for specific origins
67    pub fn for_origins(origins: Vec<String>) -> Self {
68        Self {
69            allowed_origins: origins,
70            ..Default::default()
71        }
72    }
73
74    /// Create a CORS config from environment variables
75    pub fn from_env() -> Self {
76        let allowed_origins = std::env::var("MCP_CORS_ORIGINS")
77            .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
78            .unwrap_or_else(|_| vec!["*".to_string()]);
79
80        let allow_credentials = std::env::var("MCP_CORS_CREDENTIALS")
81            .map(|s| s.parse().unwrap_or(false))
82            .unwrap_or(false);
83
84        let max_age = std::env::var("MCP_CORS_MAX_AGE")
85            .ok()
86            .and_then(|s| s.parse().ok());
87
88        Self {
89            allowed_origins,
90            allow_credentials,
91            max_age,
92            ..Default::default()
93        }
94    }
95}
96
97/// Inject CORS headers into a Lambda response (generic over body type)
98///
99/// This function adds the appropriate CORS headers based on the configuration
100/// and the incoming request's Origin header.
101pub fn inject_cors_headers<B>(
102    response: &mut lambda_http::Response<B>,
103    config: &CorsConfig,
104    request_origin: Option<&str>,
105) -> Result<()> {
106    debug!("Injecting CORS headers for origin: {:?}", request_origin);
107
108    // Determine allowed origin
109    let allowed_origin = determine_allowed_origin(config, request_origin);
110
111    if let Some(origin) = allowed_origin {
112        response.headers_mut().insert(
113            "Access-Control-Allow-Origin",
114            HeaderValue::from_str(&origin)
115                .map_err(|e| LambdaError::Cors(format!("Invalid origin: {}", e)))?,
116        );
117    }
118
119    // Add allowed methods
120    let methods_str = config
121        .allowed_methods
122        .iter()
123        .map(|m| m.as_str())
124        .collect::<Vec<_>>()
125        .join(", ");
126    response.headers_mut().insert(
127        "Access-Control-Allow-Methods",
128        HeaderValue::from_str(&methods_str)
129            .map_err(|e| LambdaError::Cors(format!("Invalid methods: {}", e)))?,
130    );
131
132    // Add allowed headers
133    if !config.allowed_headers.is_empty() {
134        let headers_str = config.allowed_headers.join(", ");
135        response.headers_mut().insert(
136            "Access-Control-Allow-Headers",
137            HeaderValue::from_str(&headers_str)
138                .map_err(|e| LambdaError::Cors(format!("Invalid headers: {}", e)))?,
139        );
140    }
141
142    // Add exposed headers
143    if !config.expose_headers.is_empty() {
144        let expose_str = config.expose_headers.join(", ");
145        response.headers_mut().insert(
146            "Access-Control-Expose-Headers",
147            HeaderValue::from_str(&expose_str)
148                .map_err(|e| LambdaError::Cors(format!("Invalid expose headers: {}", e)))?,
149        );
150    }
151
152    // Add credentials if allowed
153    if config.allow_credentials {
154        response.headers_mut().insert(
155            "Access-Control-Allow-Credentials",
156            HeaderValue::from_static("true"),
157        );
158    }
159
160    // Add max age for preflight requests
161    if let Some(max_age) = config.max_age {
162        response.headers_mut().insert(
163            "Access-Control-Max-Age",
164            HeaderValue::from_str(&max_age.to_string())
165                .map_err(|e| LambdaError::Cors(format!("Invalid max age: {}", e)))?,
166        );
167    }
168
169    debug!("CORS headers injected successfully");
170    Ok(())
171}
172
173/// Create a CORS preflight response
174///
175/// Handles OPTIONS requests that browsers send before making actual CORS requests.
176pub fn create_preflight_response(
177    config: &CorsConfig,
178    request_origin: Option<&str>,
179) -> Result<LambdaResponse<LambdaBody>> {
180    debug!("Creating CORS preflight response");
181
182    let mut response = LambdaResponse::builder()
183        .status(200)
184        .body(LambdaBody::Empty)
185        .map_err(LambdaError::Http)?;
186
187    inject_cors_headers(&mut response, config, request_origin)?;
188
189    Ok(response)
190}
191
192/// Determine the allowed origin based on configuration and request
193fn determine_allowed_origin(config: &CorsConfig, request_origin: Option<&str>) -> Option<String> {
194    // If wildcard is configured, return it
195    if config.allowed_origins.contains(&"*".to_string()) {
196        return Some("*".to_string());
197    }
198
199    // If no origin in request, no CORS header needed
200    let request_origin = request_origin?;
201
202    // Check if the request origin is in the allowed list
203    if config.allowed_origins.contains(&request_origin.to_string()) {
204        Some(request_origin.to_string())
205    } else {
206        // Origin not allowed, don't set CORS header
207        None
208    }
209}
210
211/// Validate CORS configuration
212pub fn validate_config(config: &CorsConfig) -> Result<()> {
213    // Check for wildcard with credentials (security issue)
214    if config.allow_credentials && config.allowed_origins.contains(&"*".to_string()) {
215        return Err(LambdaError::Cors(
216            "Cannot use wildcard origin (*) with credentials enabled".to_string(),
217        ));
218    }
219
220    // Validate origins are proper URLs or wildcards
221    for origin in &config.allowed_origins {
222        if origin != "*" && !origin.starts_with("http://") && !origin.starts_with("https://") {
223            return Err(LambdaError::Cors(format!(
224                "Invalid origin format: {}",
225                origin
226            )));
227        }
228    }
229
230    // Check for duplicate headers
231    let headers_set: HashSet<_> = config.allowed_headers.iter().collect();
232    if headers_set.len() != config.allowed_headers.len() {
233        return Err(LambdaError::Cors(
234            "Duplicate headers in allowed_headers".to_string(),
235        ));
236    }
237
238    Ok(())
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use lambda_http::Body;

    /// An empty `200` response to inject headers into.
    fn empty_response() -> LambdaResponse<Body> {
        LambdaResponse::builder()
            .status(200)
            .body(Body::Empty)
            .expect("static status and empty body are always valid")
    }

    #[test]
    fn test_default_config() {
        let config = CorsConfig::default();
        assert!(config.allowed_origins.contains(&"*".to_string()));
        assert!(config.allowed_methods.contains(&Method::GET));
        assert!(config.allowed_methods.contains(&Method::POST));
        assert!(config.allowed_headers.contains(&"Content-Type".to_string()));
    }

    #[test]
    fn test_config_validation() {
        let mut config = CorsConfig::default();
        assert!(validate_config(&config).is_ok());

        // Wildcard origin with credentials is rejected.
        config.allow_credentials = true;
        assert!(validate_config(&config).is_err());

        // Origins without an http(s) scheme are rejected.
        config.allow_credentials = false;
        config.allowed_origins = vec!["invalid-origin".to_string()];
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_duplicate_headers_rejected() {
        let mut config = CorsConfig::default();
        config.allowed_headers.push("Content-Type".to_string());
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_cors_headers_injection() {
        let config = CorsConfig::default();
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://example.com")).unwrap();

        assert_eq!(
            response.headers().get("access-control-allow-origin"),
            Some(&HeaderValue::from_static("*"))
        );
        assert!(response
            .headers()
            .contains_key("access-control-allow-methods"));
        assert!(response
            .headers()
            .contains_key("access-control-allow-headers"));
    }

    #[test]
    fn test_allowed_specific_origin_is_echoed() {
        let config = CorsConfig::for_origins(vec!["https://app.example.com".to_string()]);
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://app.example.com")).unwrap();

        assert_eq!(
            response.headers().get("access-control-allow-origin"),
            Some(&HeaderValue::from_static("https://app.example.com"))
        );
    }

    #[test]
    fn test_disallowed_origin_gets_no_allow_origin() {
        let config = CorsConfig::for_origins(vec!["https://app.example.com".to_string()]);
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://evil.example.com")).unwrap();

        assert!(!response
            .headers()
            .contains_key("access-control-allow-origin"));
    }

    #[test]
    fn test_credentials_header_emitted_when_enabled() {
        let mut config = CorsConfig::for_origins(vec!["https://app.example.com".to_string()]);
        config.allow_credentials = true;
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://app.example.com")).unwrap();

        assert_eq!(
            response.headers().get("access-control-allow-credentials"),
            Some(&HeaderValue::from_static("true"))
        );
    }

    #[test]
    fn test_preflight_response() {
        let config = CorsConfig::default();
        let response = create_preflight_response(&config, Some("https://example.com")).unwrap();

        assert_eq!(response.status(), 200);
        assert!(response
            .headers()
            .contains_key("access-control-allow-origin"));
        assert!(response
            .headers()
            .contains_key("access-control-allow-methods"));
    }
}