ruvector_security/
cors.rs1use std::time::Duration;
6
7#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
9#[serde(rename_all = "lowercase")]
10pub enum CorsMode {
11 #[default]
13 Restrictive,
14 Development,
16 Custom,
18}
19
20#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
22pub struct CorsConfig {
23 pub mode: CorsMode,
25 pub allowed_origins: Vec<String>,
27 pub allowed_methods: Vec<String>,
29 pub allowed_headers: Vec<String>,
31 pub exposed_headers: Vec<String>,
33 pub allow_credentials: bool,
35 pub max_age_secs: u64,
37}
38
39impl Default for CorsConfig {
40 fn default() -> Self {
41 Self {
42 mode: CorsMode::Restrictive,
43 allowed_origins: vec![],
44 allowed_methods: vec![
45 "GET".to_string(),
46 "POST".to_string(),
47 "PUT".to_string(),
48 "DELETE".to_string(),
49 "OPTIONS".to_string(),
50 ],
51 allowed_headers: vec![
52 "Content-Type".to_string(),
53 "Authorization".to_string(),
54 "X-Request-ID".to_string(),
55 ],
56 exposed_headers: vec!["X-Request-ID".to_string(), "X-RateLimit-Remaining".to_string()],
57 allow_credentials: false,
58 max_age_secs: 3600,
59 }
60 }
61}
62
63impl CorsConfig {
64 pub fn development() -> Self {
66 Self {
67 mode: CorsMode::Development,
68 ..Default::default()
69 }
70 }
71
72 pub fn production(allowed_origins: Vec<String>) -> Self {
74 Self {
75 mode: CorsMode::Restrictive,
76 allowed_origins,
77 allow_credentials: true,
78 ..Default::default()
79 }
80 }
81
82 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
84 self.allowed_origins.push(origin.into());
85 self
86 }
87
88 pub fn allow_methods(mut self, methods: Vec<String>) -> Self {
90 self.allowed_methods = methods;
91 self
92 }
93
94 pub fn max_age(&self) -> Duration {
96 Duration::from_secs(self.max_age_secs)
97 }
98
99 pub fn is_origin_allowed(&self, origin: &str) -> bool {
101 match self.mode {
102 CorsMode::Development => true,
103 CorsMode::Restrictive | CorsMode::Custom => {
104 self.allowed_origins.iter().any(|allowed| {
105 if allowed == "*" {
106 return true;
107 }
108 if let Some(suffix) = allowed.strip_prefix("*.") {
110 return origin.ends_with(suffix)
111 || origin == format!("https://{}", suffix)
112 || origin == format!("http://{}", suffix);
113 }
114 origin == allowed
115 })
116 }
117 }
118 }
119}
120
/// Builds a `tower_http` CORS layer that enforces `config`.
///
/// - `Development` mode yields `CorsLayer::permissive()`.
/// - Otherwise, the origin rule depends on `allowed_origins`:
///   - an empty list rejects every cross-origin request;
///   - a `*` entry allows any origin;
///   - a list containing `*.domain` patterns is evaluated per request with
///     [`CorsConfig::is_origin_allowed`], so the layer and that method agree;
///   - a plain list of exact origins is passed to the layer as-is.
/// - Entries in `allowed_origins`, `allowed_methods`, `allowed_headers` and
///   `exposed_headers` that do not parse as HTTP values are skipped.
///
/// NOTE(review): tower-http documents that `allow_credentials(true)` cannot be combined
/// with a wildcard origin and panics when such a layer is applied. An example is
/// `CorsConfig::production(vec!["*".into()])`. Confirm against the tower-http version in use.
#[cfg(feature = "middleware")]
pub fn build_cors_layer(config: &CorsConfig) -> tower_http::cors::CorsLayer {
    use http::header::{HeaderName, HeaderValue};
    use http::Method;
    use tower_http::cors::{AllowOrigin, CorsLayer};

    match config.mode {
        CorsMode::Development => CorsLayer::permissive(),
        CorsMode::Restrictive | CorsMode::Custom => {
            let allow_origin = if config.allowed_origins.is_empty() {
                // No configured origins: deny every cross-origin request.
                AllowOrigin::list(std::iter::empty::<HeaderValue>())
            } else if config.allowed_origins.iter().any(|o| o == "*") {
                AllowOrigin::any()
            } else if config.allowed_origins.iter().any(|o| o.starts_with("*.")) {
                // `*.domain` entries are patterns, not literal header values, so a plain
                // list would never match them; evaluate them per request instead.
                let matcher = config.clone();
                AllowOrigin::predicate(
                    move |origin: &HeaderValue, _parts: &http::request::Parts| {
                        match origin.to_str() {
                            Ok(origin) => matcher.is_origin_allowed(origin),
                            Err(_) => false,
                        }
                    },
                )
            } else {
                AllowOrigin::list(
                    config
                        .allowed_origins
                        .iter()
                        .filter_map(|o| o.parse::<HeaderValue>().ok()),
                )
            };

            let methods: Vec<Method> = config
                .allowed_methods
                .iter()
                .filter_map(|m| m.parse().ok())
                .collect();
            let headers: Vec<HeaderName> = config
                .allowed_headers
                .iter()
                .filter_map(|h| h.parse().ok())
                .collect();
            let exposed: Vec<HeaderName> = config
                .exposed_headers
                .iter()
                .filter_map(|h| h.parse().ok())
                .collect();

            let mut layer = CorsLayer::new()
                .allow_origin(allow_origin)
                .allow_methods(methods)
                .allow_headers(headers)
                .expose_headers(exposed)
                .max_age(config.max_age());

            if config.allow_credentials {
                layer = layer.allow_credentials(true);
            }

            layer
        }
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_development_allows_all() {
        let config = CorsConfig::development();
        assert!(config.is_origin_allowed("https://example.com"));
        assert!(config.is_origin_allowed("http://localhost:3000"));
    }

    #[test]
    fn test_restrictive_checks_origins() {
        let config = CorsConfig::production(vec!["https://app.example.com".to_string()]);

        assert!(config.is_origin_allowed("https://app.example.com"));
        assert!(!config.is_origin_allowed("https://evil.com"));
    }

    #[test]
    fn test_wildcard_subdomain() {
        let config = CorsConfig::production(vec!["*.example.com".to_string()]);

        assert!(config.is_origin_allowed("https://app.example.com"));
        assert!(config.is_origin_allowed("https://api.example.com"));
        assert!(!config.is_origin_allowed("https://example.org"));
    }

    #[test]
    fn test_default_denies_every_origin() {
        let config = CorsConfig::default();

        assert_eq!(config.mode, CorsMode::Restrictive);
        assert!(!config.allow_credentials);
        assert!(!config.is_origin_allowed("https://example.com"));
    }

    #[test]
    fn test_star_allows_any_origin() {
        let config = CorsConfig::default().allow_origin("*");

        assert!(config.is_origin_allowed("https://anything.test"));
    }

    #[test]
    fn test_wildcard_subdomain_includes_bare_domain() {
        let config = CorsConfig::production(vec!["*.example.com".to_string()]);

        assert!(config.is_origin_allowed("https://example.com"));
        assert!(config.is_origin_allowed("http://example.com"));
    }

    #[test]
    fn test_custom_mode_uses_allow_list() {
        let config = CorsConfig {
            mode: CorsMode::Custom,
            ..CorsConfig::production(vec!["https://app.example.com".to_string()])
        };

        assert!(config.is_origin_allowed("https://app.example.com"));
        assert!(!config.is_origin_allowed("https://other.example.com"));
    }

    #[test]
    fn test_builders_and_max_age() {
        let config = CorsConfig::default()
            .allow_origin("https://a.test")
            .allow_methods(vec!["GET".to_string()]);

        assert_eq!(config.allowed_origins, vec!["https://a.test".to_string()]);
        assert_eq!(config.allowed_methods, vec!["GET".to_string()]);
        assert_eq!(config.max_age(), std::time::Duration::from_secs(3600));
    }

    #[test]
    fn test_production_enables_credentials() {
        let config = CorsConfig::production(vec!["https://app.example.com".to_string()]);

        assert_eq!(config.mode, CorsMode::Restrictive);
        assert!(config.allow_credentials);
    }
}