wae_https/middleware/
cors.rs1#![doc = include_str!("readme.md")]
2
3use std::time::Duration;
4
5use http::{HeaderName, HeaderValue, Method, header};
6use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer};
7
8pub struct CorsConfig {
12 pub allowed_origins: Vec<HeaderValue>,
14 pub allowed_methods: Vec<Method>,
16 pub allowed_headers: Vec<HeaderName>,
18 pub allow_credentials: bool,
20 pub max_age: u64,
22}
23
24impl CorsConfig {
25 pub fn new() -> Self {
34 Self {
35 allowed_origins: Vec::new(),
36 allowed_methods: Vec::new(),
37 allowed_headers: Vec::new(),
38 allow_credentials: false,
39 max_age: 600,
40 }
41 }
42
43 pub fn allow_origin(mut self, origin: impl Into<String>) -> Self {
49 if let Ok(value) = HeaderValue::from_str(&origin.into()) {
50 self.allowed_origins.push(value);
51 }
52 self
53 }
54
55 pub fn allow_origins<I, S>(mut self, origins: I) -> Self
61 where
62 I: IntoIterator<Item = S>,
63 S: Into<String>,
64 {
65 self.allowed_origins = origins.into_iter().filter_map(|s| HeaderValue::from_str(&s.into()).ok()).collect();
66 self
67 }
68
69 pub fn allow_method(mut self, method: impl Into<String>) -> Self {
75 if let Ok(m) = Method::from_bytes(method.into().as_bytes()) {
76 self.allowed_methods.push(m);
77 }
78 self
79 }
80
81 pub fn allow_methods<I, S>(mut self, methods: I) -> Self
87 where
88 I: IntoIterator<Item = S>,
89 S: Into<String>,
90 {
91 self.allowed_methods = methods.into_iter().filter_map(|m| Method::from_bytes(m.into().as_bytes()).ok()).collect();
92 self
93 }
94
95 pub fn allow_header(mut self, header_name: impl Into<String>) -> Self {
101 if let Ok(name) = HeaderName::try_from(header_name.into()) {
102 self.allowed_headers.push(name);
103 }
104 self
105 }
106
107 pub fn allow_headers<I, S>(mut self, headers: I) -> Self
113 where
114 I: IntoIterator<Item = S>,
115 S: Into<String>,
116 {
117 self.allowed_headers = headers.into_iter().filter_map(|h| HeaderName::try_from(h.into()).ok()).collect();
118 self
119 }
120
121 pub fn allow_credentials(mut self, allow: bool) -> Self {
131 self.allow_credentials = allow;
132 self
133 }
134
135 pub fn max_age(mut self, seconds: u64) -> Self {
141 self.max_age = seconds;
142 self
143 }
144
145 pub fn into_layer(self) -> CorsLayer {
151 let mut cors = CorsLayer::new();
152
153 cors = if self.allowed_origins.is_empty() {
154 cors.allow_origin(AllowOrigin::any())
155 }
156 else {
157 cors.allow_origin(AllowOrigin::list(self.allowed_origins))
158 };
159
160 cors = if self.allowed_methods.is_empty() {
161 cors.allow_methods(AllowMethods::any())
162 }
163 else {
164 cors.allow_methods(AllowMethods::list(self.allowed_methods))
165 };
166
167 cors = if self.allowed_headers.is_empty() {
168 cors.allow_headers(AllowHeaders::any())
169 }
170 else {
171 cors.allow_headers(AllowHeaders::list(self.allowed_headers))
172 };
173
174 cors = if self.allow_credentials { cors.allow_credentials(true) } else { cors };
175
176 cors.max_age(Duration::from_secs(self.max_age))
177 }
178}
179
180impl Default for CorsConfig {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
186pub fn cors_permissive() -> CorsLayer {
190 CorsLayer::permissive()
191}
192
193pub fn cors_strict() -> CorsLayer {
197 CorsConfig::new()
198 .allow_methods(["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
199 .allow_headers([header::CONTENT_TYPE.as_str(), header::AUTHORIZATION.as_str(), header::ACCEPT.as_str()])
200 .max_age(3600)
201 .into_layer()
202}