ruvector_security/
cors.rs

//! CORS (Cross-Origin Resource Sharing) configuration
//!
//! Provides configurable CORS policies for production and development.
4
5use std::time::Duration;
6
7/// CORS mode
8#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
9#[serde(rename_all = "lowercase")]
10pub enum CorsMode {
11    /// Restrictive CORS (production default)
12    #[default]
13    Restrictive,
14    /// Permissive CORS (development only)
15    Development,
16    /// Custom CORS configuration
17    Custom,
18}
19
20/// CORS configuration
21#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
22pub struct CorsConfig {
23    /// CORS mode
24    pub mode: CorsMode,
25    /// Allowed origins (for Restrictive/Custom mode)
26    pub allowed_origins: Vec<String>,
27    /// Allowed methods
28    pub allowed_methods: Vec<String>,
29    /// Allowed headers
30    pub allowed_headers: Vec<String>,
31    /// Exposed headers
32    pub exposed_headers: Vec<String>,
33    /// Allow credentials
34    pub allow_credentials: bool,
35    /// Max age in seconds
36    pub max_age_secs: u64,
37}
38
39impl Default for CorsConfig {
40    fn default() -> Self {
41        Self {
42            mode: CorsMode::Restrictive,
43            allowed_origins: vec![],
44            allowed_methods: vec![
45                "GET".to_string(),
46                "POST".to_string(),
47                "PUT".to_string(),
48                "DELETE".to_string(),
49                "OPTIONS".to_string(),
50            ],
51            allowed_headers: vec![
52                "Content-Type".to_string(),
53                "Authorization".to_string(),
54                "X-Request-ID".to_string(),
55            ],
56            exposed_headers: vec!["X-Request-ID".to_string(), "X-RateLimit-Remaining".to_string()],
57            allow_credentials: false,
58            max_age_secs: 3600,
59        }
60    }
61}
62
63impl CorsConfig {
64    /// Create a development CORS configuration (permissive)
65    pub fn development() -> Self {
66        Self {
67            mode: CorsMode::Development,
68            ..Default::default()
69        }
70    }
71
72    /// Create a production CORS configuration
73    pub fn production(allowed_origins: Vec<String>) -> Self {
74        Self {
75            mode: CorsMode::Restrictive,
76            allowed_origins,
77            allow_credentials: true,
78            ..Default::default()
79        }
80    }
81
82    /// Add an allowed origin
83    pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
84        self.allowed_origins.push(origin.into());
85        self
86    }
87
88    /// Set allowed methods
89    pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
90        self.allowed_methods = methods;
91        self
92    }
93
94    /// Get max age as Duration
95    pub fn max_age(&self) -> Duration {
96        Duration::from_secs(self.max_age_secs)
97    }
98
99    /// Check if an origin is allowed
100    pub fn is_origin_allowed(&self, origin: &str) -> bool {
101        match self.mode {
102            CorsMode::Development => true,
103            CorsMode::Restrictive | CorsMode::Custom => {
104                self.allowed_origins.iter().any(|allowed| {
105                    if allowed == "*" {
106                        return true;
107                    }
108                    // Support wildcard subdomains like *.example.com
109                    if let Some(suffix) = allowed.strip_prefix("*.") {
110                        return origin.ends_with(suffix)
111                            || origin == format!("https://{}", suffix)
112                            || origin == format!("http://{}", suffix);
113                    }
114                    origin == allowed
115                })
116            }
117        }
118    }
119}
120
/// Build a tower-http CORS layer from configuration.
///
/// * [`CorsMode::Development`] returns `CorsLayer::permissive()` and ignores
///   every other field.
/// * [`CorsMode::Restrictive`] / [`CorsMode::Custom`] translate the config
///   field by field. Origins are handled as follows:
///   - no origins: every cross-origin request is rejected;
///   - any `*` entry: every origin is allowed;
///   - any `*.domain` entry: each request's origin is checked with
///     [`CorsConfig::is_origin_allowed`], so wildcard subdomains behave
///     exactly as they do there (a literal `*.domain` header value would
///     never match a real origin);
///   - otherwise: the exact list of origins.
///
/// Origin, method and header strings that fail to parse are skipped.
///
/// Credentials are NOT enabled together with a `*` origin: the CORS protocol
/// forbids `Access-Control-Allow-Credentials: true` alongside
/// `Access-Control-Allow-Origin: *`, and tower-http refuses to build such a
/// layer.
#[cfg(feature = "middleware")]
pub fn build_cors_layer(config: &CorsConfig) -> tower_http::cors::CorsLayer {
    use http::header::{HeaderName, HeaderValue};
    use http::Method;
    use tower_http::cors::{AllowOrigin, CorsLayer};

    /// Parse every entry, silently skipping the ones that are invalid.
    fn parse_all<T: std::str::FromStr>(items: &[String]) -> Vec<T> {
        items.iter().filter_map(|item| item.parse().ok()).collect()
    }

    match config.mode {
        CorsMode::Development => CorsLayer::permissive(),
        CorsMode::Restrictive | CorsMode::Custom => {
            let any_origin = config.allowed_origins.iter().any(|o| o == "*");
            let wildcard_subdomain = config.allowed_origins.iter().any(|o| o.starts_with("*."));

            let allow_origin = if config.allowed_origins.is_empty() {
                // No origins = block all cross-origin requests.
                AllowOrigin::list(std::iter::empty::<HeaderValue>())
            } else if any_origin {
                AllowOrigin::any()
            } else if wildcard_subdomain {
                // Delegate to the same matcher the rest of the crate uses.
                let matcher = config.clone();
                AllowOrigin::predicate(
                    move |origin: &HeaderValue, _parts: &http::request::Parts| {
                        origin
                            .to_str()
                            .map(|origin| matcher.is_origin_allowed(origin))
                            .unwrap_or(false)
                    },
                )
            } else {
                AllowOrigin::list(parse_all::<HeaderValue>(&config.allowed_origins))
            };

            let mut layer = CorsLayer::new()
                .allow_origin(allow_origin)
                .allow_methods(parse_all::<Method>(&config.allowed_methods))
                .allow_headers(parse_all::<HeaderName>(&config.allowed_headers))
                .expose_headers(parse_all::<HeaderName>(&config.exposed_headers))
                .max_age(config.max_age());

            if config.allow_credentials && !any_origin {
                layer = layer.allow_credentials(true);
            }

            layer
        }
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_development_allows_all() {
        let config = CorsConfig::development();
        assert!(config.is_origin_allowed("https://example.com"));
        assert!(config.is_origin_allowed("http://localhost:3000"));
    }

    #[test]
    fn test_restrictive_checks_origins() {
        let config = CorsConfig::production(vec!["https://app.example.com".to_string()]);

        assert!(config.is_origin_allowed("https://app.example.com"));
        assert!(!config.is_origin_allowed("https://evil.com"));
    }

    #[test]
    fn test_wildcard_subdomain() {
        let config = CorsConfig::production(vec!["*.example.com".to_string()]);

        assert!(config.is_origin_allowed("https://app.example.com"));
        assert!(config.is_origin_allowed("https://api.example.com"));
        assert!(config.is_origin_allowed("http://app.example.com"));
        // The bare domain is accepted too.
        assert!(config.is_origin_allowed("https://example.com"));
        assert!(!config.is_origin_allowed("https://example.org"));
    }

    #[test]
    fn test_default_denies_everything() {
        let config = CorsConfig::default();

        assert_eq!(config.mode, CorsMode::Restrictive);
        assert!(!config.allow_credentials);
        assert!(!config.is_origin_allowed("https://example.com"));
        assert!(!config.is_origin_allowed("http://localhost:3000"));
    }

    #[test]
    fn test_any_origin_wildcard() {
        let config = CorsConfig::production(vec!["*".to_string()]);

        assert!(config.is_origin_allowed("https://anything.test"));
        assert!(config.is_origin_allowed("http://localhost:8080"));
    }

    #[test]
    fn test_builders_and_max_age() {
        let config = CorsConfig::default()
            .allow_origin("https://a.test")
            .allow_methods(vec!["GET".to_string()]);

        assert!(config.is_origin_allowed("https://a.test"));
        assert!(!config.is_origin_allowed("https://b.test"));
        assert_eq!(config.allowed_methods, vec!["GET".to_string()]);
        assert_eq!(config.max_age().as_secs(), 3600);
    }
}