use std::{collections::HashSet, fmt, sync::Arc};

use crate::{
    header::{
        HeaderMap, HeaderName, HeaderValue, ACCESS_CONTROL_ALLOW_CREDENTIALS,
        ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_ORIGIN, ACCESS_CONTROL_MAX_AGE,
        ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, ORIGIN, VARY,
    },
    headers::{
        AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders,
        HeaderMapExt,
    },
    Handler, IntoResponse, Method, Request, RequestExt, Response, Result, StatusCode, Transform,
};
17
/// Configuration for the CORS middleware.
///
/// Construct it with [`Config::new`] (or [`Default`]) plus the chained setters,
/// then use it as a [`Transform`] to wrap a handler in a [`CorsMiddleware`].
pub struct Config {
    /// Preflight cache lifetime setting (default `86400`); presumably seconds,
    /// matching `Access-Control-Max-Age` semantics — TODO confirm.
    max_age: usize,
    /// When `true`, responses carry `Access-Control-Allow-Credentials: true`.
    credentials: bool,
    /// Methods accepted in a preflight's `Access-Control-Request-Method`.
    allow_methods: HashSet<Method>,
    /// Header names checked against a preflight's
    /// `Access-Control-Request-Headers` and advertised in the response.
    allow_headers: HashSet<HeaderName>,
    /// Origins accepted for CORS requests (the default contains `*`).
    allow_origins: HashSet<HeaderValue>,
    /// Header names sent in `Access-Control-Expose-Headers` on non-preflight
    /// responses.
    expose_headers: HashSet<HeaderName>,
    /// Optional predicate an origin must also pass; `None` skips the check.
    origin_verify: Option<Arc<dyn Fn(&HeaderValue) -> bool + Send + Sync>>,
}
28
29impl Config {
30 #[must_use]
32 pub fn new() -> Self {
33 Self::default()
34 }
35
36 #[must_use]
40 pub const fn max_age(mut self, max_age: usize) -> Self {
41 self.max_age = max_age;
42 self
43 }
44
45 #[must_use]
49 pub const fn credentials(mut self, credentials: bool) -> Self {
50 self.credentials = credentials;
51 self
52 }
53
54 #[must_use]
58 pub fn allow_methods<H>(mut self, allow_methods: H) -> Self
59 where
60 H: IntoIterator,
61 H::Item: TryInto<Method>,
62 {
63 self.allow_methods = allow_methods
64 .into_iter()
65 .map(TryInto::try_into)
66 .filter_map(Result::ok)
67 .collect();
68 self
69 }
70
71 #[must_use]
75 pub fn allow_headers<H>(mut self, allow_headers: H) -> Self
76 where
77 H: IntoIterator,
78 H::Item: TryInto<HeaderName>,
79 {
80 self.allow_headers = allow_headers
81 .into_iter()
82 .map(TryInto::try_into)
83 .filter_map(Result::ok)
84 .collect();
85 self
86 }
87
88 #[must_use]
92 pub fn allow_origins<H>(mut self, allow_origins: H) -> Self
93 where
94 H: IntoIterator,
95 H::Item: TryInto<HeaderValue>,
96 {
97 self.allow_origins = allow_origins
98 .into_iter()
99 .map(TryInto::try_into)
100 .filter_map(Result::ok)
101 .collect();
102 self
103 }
104
105 #[must_use]
109 pub fn expose_headers<H>(mut self, expose_headers: H) -> Self
110 where
111 H: IntoIterator,
112 H::Item: TryInto<HeaderName>,
113 {
114 self.expose_headers = expose_headers
115 .into_iter()
116 .map(TryInto::try_into)
117 .filter_map(Result::ok)
118 .collect();
119 self
120 }
121
122 #[must_use]
124 pub fn origin_verify(
125 mut self,
126 origin_verify: Option<Arc<dyn Fn(&HeaderValue) -> bool + Send + Sync>>,
127 ) -> Self {
128 self.origin_verify = origin_verify;
129 self
130 }
131}
132
133impl Default for Config {
134 fn default() -> Self {
135 Self {
136 max_age: 86400,
137 credentials: false,
138 allow_methods: HashSet::from([
139 Method::GET,
140 Method::POST,
141 Method::HEAD,
142 Method::PUT,
143 Method::DELETE,
144 Method::PATCH,
145 ]),
146 allow_origins: HashSet::from([HeaderValue::from_static("*")]),
147 allow_headers: HashSet::new(),
148 expose_headers: HashSet::new(),
149 origin_verify: None,
150 }
151 }
152}
153
154impl Clone for Config {
155 fn clone(&self) -> Self {
156 Self {
157 max_age: self.max_age,
158 credentials: self.credentials,
159 allow_methods: self.allow_methods.clone(),
160 allow_headers: self.allow_headers.clone(),
161 allow_origins: self.allow_origins.clone(),
162 expose_headers: self.expose_headers.clone(),
163 origin_verify: self.origin_verify.clone(),
164 }
165 }
166}
167
168impl fmt::Debug for Config {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 f.debug_struct("CorsConfig")
171 .field("max_age", &self.max_age)
172 .field("credentials", &self.credentials)
173 .field("allow_methods", &self.allow_methods)
174 .field("allow_headers", &self.allow_headers)
175 .field("allow_origins", &self.allow_origins)
176 .field("expose_headers", &self.expose_headers)
177 .finish_non_exhaustive()
178 }
179}
180
181impl<H> Transform<H> for Config {
182 type Output = CorsMiddleware<H>;
183
184 fn transform(&self, h: H) -> Self::Output {
185 CorsMiddleware {
186 h,
187 acam: self.allow_methods.clone().into_iter().collect(),
188 acah: self.allow_headers.clone().into_iter().collect(),
189 aceh: self.expose_headers.clone().into_iter().collect(),
190 config: self.clone(),
191 }
192 }
193}
194
/// CORS middleware produced by [`Config`]'s [`Transform`] implementation.
///
/// The `acam`/`acah`/`aceh` fields hold typed header values built once from
/// the config's sets in `transform`, so each request only clones them.
#[derive(Debug, Clone)]
pub struct CorsMiddleware<H> {
    /// The wrapped (inner) handler.
    h: H,
    /// Copy of the [`Config`] this middleware was built from.
    config: Config,
    /// Pre-built `Access-Control-Allow-Methods` value.
    acam: AccessControlAllowMethods,
    /// Pre-built `Access-Control-Allow-Headers` value.
    acah: AccessControlAllowHeaders,
    /// Pre-built `Access-Control-Expose-Headers` value.
    aceh: AccessControlExposeHeaders,
}
204
205#[crate::async_trait]
206impl<H, O> Handler<Request> for CorsMiddleware<H>
207where
208 H: Handler<Request, Output = Result<O>>,
209 O: IntoResponse,
210{
211 type Output = Result<Response>;
212
213 async fn call(&self, req: Request) -> Self::Output {
214 let Some(origin) = req.header(ORIGIN).filter(is_not_empty) else {
215 return self.h.call(req).await.map(IntoResponse::into_response);
216 };
217
218 if !self.config.allow_origins.contains(&origin)
219 || !self
220 .config
221 .origin_verify
222 .as_ref()
223 .map_or(true, |f| (f)(&origin))
224 {
225 return Err(StatusCode::FORBIDDEN.into_error());
226 }
227
228 let mut headers = HeaderMap::new();
229 let mut resp = if req.method() == Method::OPTIONS {
230 if req
232 .header(ACCESS_CONTROL_REQUEST_METHOD)
233 .is_some_and(|method| {
234 self.config.allow_methods.is_empty()
235 || self.config.allow_methods.contains(&method)
236 })
237 {
238 headers.typed_insert(self.acam.clone());
239 } else {
240 return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error());
241 }
242
243 let (allow_headers, request_headers) = req
244 .header(ACCESS_CONTROL_REQUEST_HEADERS)
245 .map_or((true, None), |hs: HeaderValue| {
246 (
247 hs.to_str()
248 .map(|hs| {
249 hs.split(',')
250 .map(str::as_bytes)
251 .map(HeaderName::from_bytes)
252 .filter_map(Result::ok)
253 .any(|header| self.config.allow_headers.contains(&header))
254 })
255 .unwrap_or(false),
256 Some(hs),
257 )
258 });
259
260 if !allow_headers {
261 return Err((StatusCode::FORBIDDEN, "Invalid Preflight Request").into_error());
262 }
263
264 if self.config.allow_headers.is_empty() {
265 headers.insert(
266 ACCESS_CONTROL_ALLOW_HEADERS,
267 request_headers.unwrap_or(HeaderValue::from_static("*")),
268 );
269 } else {
270 headers.typed_insert(self.acah.clone());
271 }
272
273 StatusCode::NO_CONTENT.into_response()
275 } else {
276 if !self.config.expose_headers.is_empty() {
278 headers.typed_insert(self.aceh.clone());
279 }
280
281 self.h.call(req).await.map(IntoResponse::into_response)?
282 };
283
284 headers.insert(VARY, ORIGIN.into());
286 headers.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
287
288 if self.config.credentials {
289 headers.insert(
290 ACCESS_CONTROL_ALLOW_CREDENTIALS,
291 HeaderValue::from_static("true"),
292 );
293 }
294
295 resp.headers_mut().extend(headers);
296
297 Ok(resp)
298 }
299}
300
301fn is_not_empty(h: &HeaderValue) -> bool {
302 !h.is_empty()
303}