turul_mcp_aws_lambda/
cors.rs1use std::collections::HashSet;
7
8use http::{HeaderValue, Method};
9use lambda_http::{Body as LambdaBody, Response as LambdaResponse};
10use tracing::debug;
11
12use crate::error::{LambdaError, Result};
13
14#[derive(Debug, Clone)]
16pub struct CorsConfig {
17 pub allowed_origins: Vec<String>,
20
21 pub allowed_methods: Vec<Method>,
23
24 pub allowed_headers: Vec<String>,
26
27 pub allow_credentials: bool,
29
30 pub max_age: Option<u32>,
32
33 pub expose_headers: Vec<String>,
35}
36
37impl Default for CorsConfig {
38 fn default() -> Self {
39 Self {
40 allowed_origins: vec!["*".to_string()],
41 allowed_methods: vec![Method::GET, Method::POST, Method::DELETE, Method::OPTIONS],
42 allowed_headers: vec![
43 "Content-Type".to_string(),
44 "Accept".to_string(),
45 "Authorization".to_string(),
46 "Mcp-Session-Id".to_string(),
47 "Mcp-Protocol-Version".to_string(),
48 "Last-Event-ID".to_string(),
49 ],
50 allow_credentials: false,
51 max_age: Some(86400), expose_headers: vec![
53 "Mcp-Session-Id".to_string(),
54 "Mcp-Protocol-Version".to_string(),
55 ],
56 }
57 }
58}
59
60impl CorsConfig {
61 pub fn allow_all() -> Self {
63 Self::default()
64 }
65
66 pub fn for_origins(origins: Vec<String>) -> Self {
68 Self {
69 allowed_origins: origins,
70 ..Default::default()
71 }
72 }
73
74 pub fn from_env() -> Self {
76 let allowed_origins = std::env::var("MCP_CORS_ORIGINS")
77 .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
78 .unwrap_or_else(|_| vec!["*".to_string()]);
79
80 let allow_credentials = std::env::var("MCP_CORS_CREDENTIALS")
81 .map(|s| s.parse().unwrap_or(false))
82 .unwrap_or(false);
83
84 let max_age = std::env::var("MCP_CORS_MAX_AGE")
85 .ok()
86 .and_then(|s| s.parse().ok());
87
88 Self {
89 allowed_origins,
90 allow_credentials,
91 max_age,
92 ..Default::default()
93 }
94 }
95}
96
97pub fn inject_cors_headers<B>(
102 response: &mut lambda_http::Response<B>,
103 config: &CorsConfig,
104 request_origin: Option<&str>,
105) -> Result<()> {
106 debug!("Injecting CORS headers for origin: {:?}", request_origin);
107
108 let allowed_origin = determine_allowed_origin(config, request_origin);
110
111 if let Some(origin) = allowed_origin {
112 response.headers_mut().insert(
113 "Access-Control-Allow-Origin",
114 HeaderValue::from_str(&origin)
115 .map_err(|e| LambdaError::Cors(format!("Invalid origin: {}", e)))?,
116 );
117 }
118
119 let methods_str = config
121 .allowed_methods
122 .iter()
123 .map(|m| m.as_str())
124 .collect::<Vec<_>>()
125 .join(", ");
126 response.headers_mut().insert(
127 "Access-Control-Allow-Methods",
128 HeaderValue::from_str(&methods_str)
129 .map_err(|e| LambdaError::Cors(format!("Invalid methods: {}", e)))?,
130 );
131
132 if !config.allowed_headers.is_empty() {
134 let headers_str = config.allowed_headers.join(", ");
135 response.headers_mut().insert(
136 "Access-Control-Allow-Headers",
137 HeaderValue::from_str(&headers_str)
138 .map_err(|e| LambdaError::Cors(format!("Invalid headers: {}", e)))?,
139 );
140 }
141
142 if !config.expose_headers.is_empty() {
144 let expose_str = config.expose_headers.join(", ");
145 response.headers_mut().insert(
146 "Access-Control-Expose-Headers",
147 HeaderValue::from_str(&expose_str)
148 .map_err(|e| LambdaError::Cors(format!("Invalid expose headers: {}", e)))?,
149 );
150 }
151
152 if config.allow_credentials {
154 response.headers_mut().insert(
155 "Access-Control-Allow-Credentials",
156 HeaderValue::from_static("true"),
157 );
158 }
159
160 if let Some(max_age) = config.max_age {
162 response.headers_mut().insert(
163 "Access-Control-Max-Age",
164 HeaderValue::from_str(&max_age.to_string())
165 .map_err(|e| LambdaError::Cors(format!("Invalid max age: {}", e)))?,
166 );
167 }
168
169 debug!("CORS headers injected successfully");
170 Ok(())
171}
172
173pub fn create_preflight_response(
177 config: &CorsConfig,
178 request_origin: Option<&str>,
179) -> Result<LambdaResponse<LambdaBody>> {
180 debug!("Creating CORS preflight response");
181
182 let mut response = LambdaResponse::builder()
183 .status(200)
184 .body(LambdaBody::Empty)
185 .map_err(LambdaError::Http)?;
186
187 inject_cors_headers(&mut response, config, request_origin)?;
188
189 Ok(response)
190}
191
192fn determine_allowed_origin(config: &CorsConfig, request_origin: Option<&str>) -> Option<String> {
194 if config.allowed_origins.contains(&"*".to_string()) {
196 return Some("*".to_string());
197 }
198
199 let request_origin = request_origin?;
201
202 if config.allowed_origins.contains(&request_origin.to_string()) {
204 Some(request_origin.to_string())
205 } else {
206 None
208 }
209}
210
211pub fn validate_config(config: &CorsConfig) -> Result<()> {
213 if config.allow_credentials && config.allowed_origins.contains(&"*".to_string()) {
215 return Err(LambdaError::Cors(
216 "Cannot use wildcard origin (*) with credentials enabled".to_string(),
217 ));
218 }
219
220 for origin in &config.allowed_origins {
222 if origin != "*" && !origin.starts_with("http://") && !origin.starts_with("https://") {
223 return Err(LambdaError::Cors(format!(
224 "Invalid origin format: {}",
225 origin
226 )));
227 }
228 }
229
230 let headers_set: HashSet<_> = config.allowed_headers.iter().collect();
232 if headers_set.len() != config.allowed_headers.len() {
233 return Err(LambdaError::Cors(
234 "Duplicate headers in allowed_headers".to_string(),
235 ));
236 }
237
238 Ok(())
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use lambda_http::Body;

    /// Empty `200` response used as the target for header injection.
    fn empty_response() -> LambdaResponse<Body> {
        LambdaResponse::builder()
            .status(200)
            .body(Body::Empty)
            .unwrap()
    }

    #[test]
    fn test_default_config() {
        let config = CorsConfig::default();
        assert!(config.allowed_origins.contains(&"*".to_string()));
        assert!(config.allowed_methods.contains(&Method::GET));
        assert!(config.allowed_methods.contains(&Method::POST));
        assert!(config.allowed_headers.contains(&"Content-Type".to_string()));
    }

    #[test]
    fn test_config_validation() {
        let mut config = CorsConfig::default();
        assert!(validate_config(&config).is_ok());

        // Wildcard origin with credentials is rejected.
        config.allow_credentials = true;
        assert!(validate_config(&config).is_err());

        // Origins without an http(s) scheme are rejected.
        config.allow_credentials = false;
        config.allowed_origins = vec!["invalid-origin".to_string()];
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_validation_rejects_duplicate_headers() {
        let config = CorsConfig {
            allowed_headers: vec!["Accept".to_string(), "Accept".to_string()],
            ..CorsConfig::default()
        };
        assert!(validate_config(&config).is_err());
    }

    #[test]
    fn test_cors_headers_injection() {
        let config = CorsConfig::default();
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://example.com")).unwrap();

        assert_eq!(
            response.headers().get("access-control-allow-origin"),
            Some(&HeaderValue::from_static("*"))
        );
        assert!(response.headers().contains_key("access-control-allow-methods"));
        assert!(response.headers().contains_key("access-control-allow-headers"));
    }

    #[test]
    fn test_listed_origin_is_reflected() {
        let config = CorsConfig::for_origins(vec!["https://app.example.com".to_string()]);
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://app.example.com")).unwrap();

        assert_eq!(
            response.headers().get("access-control-allow-origin"),
            Some(&HeaderValue::from_static("https://app.example.com"))
        );
    }

    #[test]
    fn test_unlisted_origin_gets_no_allow_origin() {
        let config = CorsConfig::for_origins(vec!["https://app.example.com".to_string()]);
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://evil.example.com")).unwrap();

        assert!(!response.headers().contains_key("access-control-allow-origin"));
    }

    #[test]
    fn test_credentials_and_max_age_headers() {
        let config = CorsConfig {
            allow_credentials: true,
            ..CorsConfig::for_origins(vec!["https://app.example.com".to_string()])
        };
        let mut response = empty_response();

        inject_cors_headers(&mut response, &config, Some("https://app.example.com")).unwrap();

        assert_eq!(
            response.headers().get("access-control-allow-credentials"),
            Some(&HeaderValue::from_static("true"))
        );
        assert_eq!(
            response.headers().get("access-control-max-age"),
            Some(&HeaderValue::from_static("86400"))
        );
    }

    #[test]
    fn test_preflight_response() {
        let config = CorsConfig::default();
        let response = create_preflight_response(&config, Some("https://example.com")).unwrap();

        assert_eq!(response.status(), 200);
        assert!(response.headers().contains_key("access-control-allow-origin"));
        assert!(response.headers().contains_key("access-control-allow-methods"));
    }
}