Skip to main content

rust_serv/middleware/
cors.rs

//! CORS (Cross-Origin Resource Sharing) middleware.
//!
//! This module implements CORS support:
//! - Origin header validation against a configured allow-list
//! - Response headers used by preflight requests (allowed methods/headers, max age)
//! - CORS headers configuration
//! - Credentials management
//! - Advertising the allowed request methods (`Access-Control-Allow-Methods`)

use hyper::{HeaderMap, Method};
use hyper::header::{
    HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_ALLOW_HEADERS,
    ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_EXPOSE_HEADERS,
    ACCESS_CONTROL_MAX_AGE, VARY,
};
use std::sync::Arc;

14/// CORS configuration options
15#[derive(Debug, Clone)]
16pub struct CorsConfig {
17    /// List of allowed origins ( "*" for all origins)
18    pub allowed_origins: Vec<String>,
19
20    /// Allowed HTTP methods
21    pub allowed_methods: Vec<Method>,
22
23    /// Allowed request headers
24    pub allowed_headers: Vec<String>,
25
26    /// Exposed response headers
27    pub exposed_headers: Vec<String>,
28
29    /// Whether credentials are allowed
30    pub allow_credentials: bool,
31
32    /// Maximum age for preflight results (seconds)
33    pub max_age: Option<u64>,
34}
35
36impl Default for CorsConfig {
37    fn default() -> Self {
38        Self {
39            allowed_origins: vec![],
40            allowed_methods: vec![],
41            allowed_headers: vec![],
42            exposed_headers: vec![],
43            allow_credentials: false,
44            max_age: None,
45        }
46    }
47}
48
49/// CORS layer that handles cross-origin requests
50pub struct CorsLayer {
51    config: Arc<CorsConfig>,
52}
53
54impl CorsLayer {
55    /// Create a new CORS layer
56    pub fn new(config: CorsConfig) -> Self {
57        Self {
58            config: Arc::new(config),
59        }
60    }
61
62    /// Check if origin is allowed
63    fn is_origin_allowed(&self, origin: &str) -> bool {
64        if self.config.allowed_origins.is_empty() {
65            return true;
66        }
67
68        self.config.allowed_origins.iter().any(|allowed| {
69            allowed == "*" || origin == allowed
70        })
71    }
72
73    /// Add CORS headers to response
74    pub fn add_cors_headers(&self, response_headers: &mut HeaderMap, origin: Option<&str>) {
75        // Allow-Origin
76        let allow_origin_value = if self.config.allowed_origins.contains(&"*".to_string()) {
77            HeaderValue::from_static("*")
78        } else if let Some(origin) = origin {
79            if self.is_origin_allowed(origin) {
80                HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*"))
81            } else {
82                return; // Invalid origin, don't add headers
83            }
84        } else {
85            return; // No origin provided, don't add headers
86        };
87
88        response_headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin_value);
89
90        // Allow-Methods
91        let methods_value = self.config.allowed_methods
92            .iter()
93            .map(|m| m.as_str())
94            .collect::<Vec<&str>>()
95            .join(", ");
96        response_headers.insert(ACCESS_CONTROL_ALLOW_METHODS, HeaderValue::from_bytes(methods_value.as_bytes()).unwrap());
97
98        // Allow-Headers
99        if !self.config.allowed_headers.is_empty() {
100            let headers_value = self.config.allowed_headers.join(", ");
101            response_headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_bytes(headers_value.as_bytes()).unwrap());
102        }
103
104        // Allow-Credentials
105        response_headers.insert(
106            ACCESS_CONTROL_ALLOW_CREDENTIALS,
107            if self.config.allow_credentials {
108                HeaderValue::from_static("true")
109            } else {
110                HeaderValue::from_static("false")
111            }
112        );
113
114        // Expose-Headers
115        if !self.config.exposed_headers.is_empty() {
116            let exposed_value = self.config.exposed_headers.join(", ");
117            response_headers.insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_bytes(exposed_value.as_bytes()).unwrap());
118        }
119
120        // Max-Age
121        if let Some(max_age) = self.config.max_age {
122            response_headers.insert(ACCESS_CONTROL_MAX_AGE, HeaderValue::from_bytes(max_age.to_string().as_bytes()).unwrap());
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = CorsConfig::default();

        assert_eq!(config.allowed_origins.len(), 0);
        assert_eq!(config.allowed_methods.len(), 0);
        assert_eq!(config.allowed_headers.len(), 0);
        assert_eq!(config.exposed_headers.len(), 0);
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age, None);
    }

    #[test]
    fn test_allow_all_origins() {
        let config = CorsConfig {
            allowed_origins: vec![],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        assert!(cors.is_origin_allowed("https://example.com"));
        assert!(cors.is_origin_allowed("https://any-origin.com"));
        assert!(cors.is_origin_allowed("https://localhost:8080"));
    }

    #[test]
    fn test_restrict_origins() {
        let config = CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        assert!(cors.is_origin_allowed("https://trusted.com"));
        assert!(!cors.is_origin_allowed("https://untrusted.com"));
        assert!(!cors.is_origin_allowed("https://malicious.com"));
    }

    /// A disallowed or missing origin must not receive any CORS headers.
    #[test]
    fn test_disallowed_origin_gets_no_headers() {
        let cors = CorsLayer::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            ..Default::default()
        });

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://malicious.com"));
        assert!(response_headers.is_empty());

        cors.add_cors_headers(&mut response_headers, None);
        assert!(response_headers.is_empty());
    }

    /// An explicitly allowed origin is echoed back verbatim.
    #[test]
    fn test_allowed_origin_is_echoed() {
        let cors = CorsLayer::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            allowed_methods: vec![Method::GET, Method::POST],
            ..Default::default()
        });

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://trusted.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            HeaderValue::from_static("https://trusted.com")
        );
        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_METHODS).unwrap(),
            HeaderValue::from_static("GET, POST")
        );
    }

    #[test]
    fn test_credentials_config() {
        let config_with_creds = CorsConfig {
            allow_credentials: true,
            ..Default::default()
        };
        let cors = CorsLayer::new(config_with_creds);

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).unwrap(),
            HeaderValue::from_static("true")
        );

        let config_without_creds = CorsConfig::default();
        let cors = CorsLayer::new(config_without_creds);
        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).unwrap(),
            HeaderValue::from_static("false")
        );
    }

    #[test]
    fn test_max_age_config() {
        let config = CorsConfig {
            max_age: Some(3600), // 1 hour
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_MAX_AGE).unwrap(),
            HeaderValue::from_static("3600")
        );

        let config_no_max_age = CorsConfig::default();
        let cors = CorsLayer::new(config_no_max_age);
        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://example.com"));

        assert!(response_headers.get(ACCESS_CONTROL_MAX_AGE).is_none());
    }

    #[test]
    fn test_exposed_headers_config() {
        let config = CorsConfig {
            exposed_headers: vec!["X-Custom-Header".to_string(), "X-Another-Header".to_string()],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_EXPOSE_HEADERS).unwrap(),
            HeaderValue::from_bytes("X-Custom-Header, X-Another-Header".as_bytes()).unwrap()
        );
    }

    #[test]
    fn test_cors_headers_addition() {
        let config = CorsConfig::default();
        let cors = CorsLayer::new(config);

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, Some("https://example.com"));

        assert!(response_headers.contains_key(ACCESS_CONTROL_ALLOW_ORIGIN));
        assert!(response_headers.contains_key(ACCESS_CONTROL_ALLOW_METHODS));
        assert!(response_headers.contains_key(ACCESS_CONTROL_ALLOW_CREDENTIALS));
    }

    #[test]
    fn test_wildcard_origin() {
        let config = CorsConfig {
            allowed_origins: vec!["*".to_string()],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        assert!(cors.is_origin_allowed("https://any-origin.com"));
        assert!(cors.is_origin_allowed("https://example.com"));
        assert!(cors.is_origin_allowed("https://localhost:8080"));
    }

    /// A wildcard configuration sends a literal `*`, even without an Origin.
    #[test]
    fn test_wildcard_header_value() {
        let cors = CorsLayer::new(CorsConfig {
            allowed_origins: vec!["*".to_string()],
            ..Default::default()
        });

        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, None);

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            HeaderValue::from_static("*")
        );
    }
}