rust_serv/middleware/
cors.rs1use hyper::{HeaderMap, Method};
11use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_CREDENTIALS, ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE};
12use std::sync::Arc;
13
14#[derive(Debug, Clone)]
16pub struct CorsConfig {
17 pub allowed_origins: Vec<String>,
19
20 pub allowed_methods: Vec<Method>,
22
23 pub allowed_headers: Vec<String>,
25
26 pub exposed_headers: Vec<String>,
28
29 pub allow_credentials: bool,
31
32 pub max_age: Option<u64>,
34}
35
36impl Default for CorsConfig {
37 fn default() -> Self {
38 Self {
39 allowed_origins: vec![],
40 allowed_methods: vec![],
41 allowed_headers: vec![],
42 exposed_headers: vec![],
43 allow_credentials: false,
44 max_age: None,
45 }
46 }
47}
48
49pub struct CorsLayer {
51 config: Arc<CorsConfig>,
52}
53
54impl CorsLayer {
55 pub fn new(config: CorsConfig) -> Self {
57 Self {
58 config: Arc::new(config),
59 }
60 }
61
62 fn is_origin_allowed(&self, origin: &str) -> bool {
64 if self.config.allowed_origins.is_empty() {
65 return true;
66 }
67
68 self.config.allowed_origins.iter().any(|allowed| {
69 allowed == "*" || origin == allowed
70 })
71 }
72
73 pub fn add_cors_headers(&self, response_headers: &mut HeaderMap, origin: Option<&str>) {
75 let allow_origin_value = if self.config.allowed_origins.contains(&"*".to_string()) {
77 HeaderValue::from_static("*")
78 } else if let Some(origin) = origin {
79 if self.is_origin_allowed(origin) {
80 HeaderValue::from_str(origin).unwrap_or_else(|_| HeaderValue::from_static("*"))
81 } else {
82 return; }
84 } else {
85 return; };
87
88 response_headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin_value);
89
90 let methods_value = self.config.allowed_methods
92 .iter()
93 .map(|m| m.as_str())
94 .collect::<Vec<&str>>()
95 .join(", ");
96 response_headers.insert(ACCESS_CONTROL_ALLOW_METHODS, HeaderValue::from_bytes(methods_value.as_bytes()).unwrap());
97
98 if !self.config.allowed_headers.is_empty() {
100 let headers_value = self.config.allowed_headers.join(", ");
101 response_headers.insert(ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_bytes(headers_value.as_bytes()).unwrap());
102 }
103
104 response_headers.insert(
106 ACCESS_CONTROL_ALLOW_CREDENTIALS,
107 if self.config.allow_credentials {
108 HeaderValue::from_static("true")
109 } else {
110 HeaderValue::from_static("false")
111 }
112 );
113
114 if !self.config.exposed_headers.is_empty() {
116 let exposed_value = self.config.exposed_headers.join(", ");
117 response_headers.insert(ACCESS_CONTROL_EXPOSE_HEADERS, HeaderValue::from_bytes(exposed_value.as_bytes()).unwrap());
118 }
119
120 if let Some(max_age) = self.config.max_age {
122 response_headers.insert(ACCESS_CONTROL_MAX_AGE, HeaderValue::from_bytes(max_age.to_string().as_bytes()).unwrap());
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `cors` for a request from `origin` and returns the resulting headers.
    fn headers_for(cors: &CorsLayer, origin: Option<&str>) -> HeaderMap {
        let mut response_headers = HeaderMap::new();
        cors.add_cors_headers(&mut response_headers, origin);
        response_headers
    }

    #[test]
    fn test_default_config() {
        let config = CorsConfig::default();

        assert_eq!(config.allowed_origins.len(), 0);
        assert_eq!(config.allowed_methods.len(), 0);
        assert_eq!(config.allowed_headers.len(), 0);
        assert!(config.exposed_headers.is_empty());
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age, None);
    }

    #[test]
    fn test_allow_all_origins() {
        let config = CorsConfig {
            allowed_origins: vec![],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        assert!(cors.is_origin_allowed("https://example.com"));
        assert!(cors.is_origin_allowed("https://any-origin.com"));
        assert!(cors.is_origin_allowed("https://localhost:8080"));
    }

    #[test]
    fn test_restrict_origins() {
        let config = CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        assert!(cors.is_origin_allowed("https://trusted.com"));
        assert!(!cors.is_origin_allowed("https://untrusted.com"));
        assert!(!cors.is_origin_allowed("https://malicious.com"));
    }

    #[test]
    fn test_allowed_origin_is_echoed() {
        let cors = CorsLayer::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            ..Default::default()
        });

        let response_headers = headers_for(&cors, Some("https://trusted.com"));
        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
            HeaderValue::from_static("https://trusted.com")
        );
    }

    #[test]
    fn test_disallowed_or_missing_origin_gets_no_headers() {
        let cors = CorsLayer::new(CorsConfig {
            allowed_origins: vec!["https://trusted.com".to_string()],
            allow_credentials: true,
            max_age: Some(60),
            ..Default::default()
        });

        assert!(headers_for(&cors, Some("https://malicious.com")).is_empty());
        assert!(headers_for(&cors, None).is_empty());
    }

    #[test]
    fn test_credentials_config() {
        let config_with_creds = CorsConfig {
            allow_credentials: true,
            ..Default::default()
        };
        let cors = CorsLayer::new(config_with_creds);
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).unwrap(),
            HeaderValue::from_static("true")
        );

        let cors = CorsLayer::new(CorsConfig::default());
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_CREDENTIALS).unwrap(),
            HeaderValue::from_static("false")
        );
    }

    #[test]
    fn test_max_age_config() {
        let config = CorsConfig {
            max_age: Some(3600),
            ..Default::default()
        };
        let cors = CorsLayer::new(config);
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_MAX_AGE).unwrap(),
            HeaderValue::from_static("3600")
        );

        let cors = CorsLayer::new(CorsConfig::default());
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert!(response_headers.get(ACCESS_CONTROL_MAX_AGE).is_none());
    }

    #[test]
    fn test_exposed_headers_config() {
        let config = CorsConfig {
            exposed_headers: vec!["X-Custom-Header".to_string(), "X-Another-Header".to_string()],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_EXPOSE_HEADERS).unwrap(),
            HeaderValue::from_bytes("X-Custom-Header, X-Another-Header".as_bytes()).unwrap()
        );
    }

    #[test]
    fn test_allowed_methods_are_joined() {
        let cors = CorsLayer::new(CorsConfig {
            allowed_methods: vec![Method::GET, Method::POST],
            ..Default::default()
        });
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert_eq!(
            response_headers.get(ACCESS_CONTROL_ALLOW_METHODS).unwrap(),
            HeaderValue::from_static("GET, POST")
        );
    }

    #[test]
    fn test_cors_headers_addition() {
        let cors = CorsLayer::new(CorsConfig::default());
        let response_headers = headers_for(&cors, Some("https://example.com"));

        assert!(response_headers.contains_key(ACCESS_CONTROL_ALLOW_ORIGIN));
        assert!(response_headers.contains_key(ACCESS_CONTROL_ALLOW_METHODS));
        assert!(response_headers.contains_key(ACCESS_CONTROL_ALLOW_CREDENTIALS));
    }

    #[test]
    fn test_wildcard_origin() {
        let config = CorsConfig {
            allowed_origins: vec!["*".to_string()],
            ..Default::default()
        };
        let cors = CorsLayer::new(config);

        assert!(cors.is_origin_allowed("https://any-origin.com"));
        assert!(cors.is_origin_allowed("https://example.com"));
        assert!(cors.is_origin_allowed("https://localhost:8080"));

        // A wildcard policy answers with a literal `*`, with or without an origin.
        for origin in [Some("https://example.com"), None] {
            let response_headers = headers_for(&cors, origin);
            assert_eq!(
                response_headers.get(ACCESS_CONTROL_ALLOW_ORIGIN).unwrap(),
                HeaderValue::from_static("*")
            );
        }
    }
}