1use unicase;
3
4pub use self::unicase::Ascii;
5use crate::hosts::{Host, Port};
6use crate::matcher::{Matcher, Pattern};
7use std::collections::HashSet;
8use std::{fmt, ops};
9
/// Scheme part of an [`Origin`].
///
/// `Origin::parse` falls back to `Http` when the input has no `scheme://` prefix.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum OriginProtocol {
    /// The `http` scheme.
    Http,
    /// The `https` scheme.
    Https,
    /// Any other scheme (e.g. `chrome-extension`); `Origin::parse` stores it lowercased.
    Custom(String),
}
20
21#[derive(Clone, PartialEq, Eq, Debug, Hash)]
23pub struct Origin {
24 protocol: OriginProtocol,
25 host: Host,
26 as_string: String,
27 matcher: Matcher,
28}
29
30impl<T: AsRef<str>> From<T> for Origin {
31 fn from(string: T) -> Self {
32 Origin::parse(string.as_ref())
33 }
34}
35
36impl Origin {
37 fn with_host(protocol: OriginProtocol, host: Host) -> Self {
38 let string = Self::to_string(&protocol, &host);
39 let matcher = Matcher::new(&string);
40
41 Origin {
42 protocol,
43 host,
44 as_string: string,
45 matcher,
46 }
47 }
48
49 pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
52 Self::with_host(protocol, Host::new(host, port))
53 }
54
55 pub fn parse(data: &str) -> Self {
58 let mut it = data.split("://");
59 let proto = it.next().expect("split always returns non-empty iterator.");
60 let hostname = it.next();
61
62 let (proto, hostname) = match hostname {
63 None => (None, proto),
64 Some(hostname) => (Some(proto), hostname),
65 };
66
67 let proto = proto.map(str::to_lowercase);
68 let hostname = Host::parse(hostname);
69
70 let protocol = match proto {
71 None => OriginProtocol::Http,
72 Some(ref p) if p == "http" => OriginProtocol::Http,
73 Some(ref p) if p == "https" => OriginProtocol::Https,
74 Some(other) => OriginProtocol::Custom(other),
75 };
76
77 Origin::with_host(protocol, hostname)
78 }
79
80 fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
81 format!(
82 "{}://{}",
83 match *protocol {
84 OriginProtocol::Http => "http",
85 OriginProtocol::Https => "https",
86 OriginProtocol::Custom(ref protocol) => protocol,
87 },
88 &**host,
89 )
90 }
91}
92
93impl Pattern for Origin {
94 fn matches<T: AsRef<str>>(&self, other: T) -> bool {
95 self.matcher.matches(other)
96 }
97}
98
99impl ops::Deref for Origin {
100 type Target = str;
101 fn deref(&self) -> &Self::Target {
102 &self.as_string
103 }
104}
105
106#[derive(Debug, Clone, PartialEq, Eq)]
108pub enum AccessControlAllowOrigin {
109 Value(Origin),
111 Null,
113 Any,
115}
116
117impl fmt::Display for AccessControlAllowOrigin {
118 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119 write!(
120 f,
121 "{}",
122 match *self {
123 AccessControlAllowOrigin::Any => "*",
124 AccessControlAllowOrigin::Null => "null",
125 AccessControlAllowOrigin::Value(ref val) => val,
126 }
127 )
128 }
129}
130
131impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
132 fn from(s: T) -> AccessControlAllowOrigin {
133 match s.into().as_str() {
134 "all" | "*" | "any" => AccessControlAllowOrigin::Any,
135 "null" => AccessControlAllowOrigin::Null,
136 origin => AccessControlAllowOrigin::Value(origin.into()),
137 }
138 }
139}
140
/// Policy for which request headers are allowed in CORS requests.
///
/// `Eq` is derived alongside `PartialEq` (the payload is `Vec<String>`), matching
/// the other CORS types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
    /// Only these header names, plus a built-in set of always-allowed headers.
    /// `get_cors_allow_headers` compares names with `unicase::Ascii`.
    Only(Vec<String>),
    /// Every requested header is allowed.
    Any,
}
149
/// Outcome of a CORS validation step.
///
/// `T` is the value that should be sent back to the client when the request is allowed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowCors<T> {
    /// No CORS response is needed for this request: no `Origin` was sent, the
    /// request is same-origin, or no headers were requested.
    NotRequired,
    /// The request does not satisfy the configured CORS policy.
    Invalid,
    /// The request is allowed; carries the value to return.
    Ok(T),
}

impl<T> AllowCors<T> {
    /// Transforms the payload of `Ok` with `f`, leaving `NotRequired` and
    /// `Invalid` unchanged.
    pub fn map<F, O>(self, f: F) -> AllowCors<O>
    where
        F: FnOnce(T) -> O,
    {
        match self {
            AllowCors::NotRequired => AllowCors::NotRequired,
            AllowCors::Invalid => AllowCors::Invalid,
            AllowCors::Ok(val) => AllowCors::Ok(f(val)),
        }
    }
}

/// Collapses an `AllowCors` into an `Option`: `Ok(v)` becomes `Some(v)`, while both
/// `NotRequired` and `Invalid` become `None`.
///
/// This is implemented as `From` rather than `Into`, so `Option::from(..)` works as
/// well. The standard blanket impl still provides `.into()` for existing callers.
impl<T> From<AllowCors<T>> for Option<T> {
    fn from(cors: AllowCors<T>) -> Option<T> {
        match cors {
            AllowCors::NotRequired | AllowCors::Invalid => None,
            AllowCors::Ok(header) => Some(header),
        }
    }
}
187
188pub fn get_cors_allow_origin(
190 origin: Option<&str>,
191 host: Option<&str>,
192 allowed: &Option<Vec<AccessControlAllowOrigin>>,
193) -> AllowCors<AccessControlAllowOrigin> {
194 match origin {
195 None => AllowCors::NotRequired,
196 Some(ref origin) => {
197 if let Some(host) = host {
198 if origin.ends_with(host) {
200 let origin = Origin::parse(origin);
202 if &*origin.host == host {
203 return AllowCors::NotRequired;
204 }
205 }
206 }
207
208 match allowed.as_ref() {
209 None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
210 None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
211 Some(ref allowed) if *origin == "null" => allowed
212 .iter()
213 .find(|cors| **cors == AccessControlAllowOrigin::Null)
214 .cloned()
215 .map(AllowCors::Ok)
216 .unwrap_or(AllowCors::Invalid),
217 Some(ref allowed) => allowed
218 .iter()
219 .find(|cors| match **cors {
220 AccessControlAllowOrigin::Any => true,
221 AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
222 _ => false,
223 })
224 .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
225 .map(AllowCors::Ok)
226 .unwrap_or(AllowCors::Invalid),
227 }
228 }
229 }
230}
231
232pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
234 mut headers: impl Iterator<Item = T>,
235 requested_headers: impl Iterator<Item = T>,
236 cors_allow_headers: &AccessControlAllowHeaders,
237 to_result: F,
238) -> AllowCors<Vec<O>> {
239 if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
241 let are_all_allowed = headers.all(|header| {
242 let name = &Ascii::new(header.as_ref());
243 only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
244 });
245
246 if !are_all_allowed {
247 return AllowCors::Invalid;
248 }
249 }
250
251 let (filtered, headers) = match cors_allow_headers {
253 AccessControlAllowHeaders::Any => {
254 let headers = requested_headers.map(to_result).collect();
255 (false, headers)
256 }
257 AccessControlAllowHeaders::Only(only) => {
258 let mut filtered = false;
259 let headers: Vec<_> = requested_headers
260 .filter(|header| {
261 let name = &Ascii::new(header.as_ref());
262 filtered = true;
263 only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
264 })
265 .map(to_result)
266 .collect();
267
268 (filtered, headers)
269 }
270 };
271
272 if headers.is_empty() {
273 if filtered {
274 AllowCors::Invalid
275 } else {
276 AllowCors::NotRequired
277 }
278 } else {
279 AllowCors::Ok(headers)
280 }
281}
282
283lazy_static! {
284 static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
286 let mut hs = HashSet::new();
287 hs.insert(Ascii::new("Accept"));
288 hs.insert(Ascii::new("Accept-Language"));
289 hs.insert(Ascii::new("Access-Control-Allow-Origin"));
290 hs.insert(Ascii::new("Access-Control-Request-Headers"));
291 hs.insert(Ascii::new("Content-Language"));
292 hs.insert(Ascii::new("Content-Type"));
293 hs.insert(Ascii::new("Host"));
294 hs.insert(Ascii::new("Origin"));
295 hs.insert(Ascii::new("Content-Length"));
296 hs.insert(Ascii::new("Connection"));
297 hs.insert(Ascii::new("User-Agent"));
298 hs
299 };
300}
301
#[cfg(test)]
mod tests {
    use std::iter;

    use super::*;
    use crate::hosts::Host;

    #[test]
    fn should_parse_origin() {
        use self::OriginProtocol::*;

        assert_eq!(Origin::parse("http://superstring.ch"), Origin::new(Http, "superstring.ch", None));
        // The scheme must be `https` for the expected `Https` protocol.
        assert_eq!(
            Origin::parse("https://superstring.ch:8443"),
            Origin::new(Https, "superstring.ch", Some(8443))
        );
        assert_eq!(
            Origin::parse("chrome-extension://124.0.0.1"),
            Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
        );
        assert_eq!(
            Origin::parse("superstring.ch/somepath"),
            Origin::new(Http, "superstring.ch", None)
        );
        assert_eq!(
            Origin::parse("127.0.0.1:8545/somepath"),
            Origin::new(Http, "127.0.0.1", Some(8545))
        );
    }

    #[test]
    fn should_not_allow_partially_matching_origin() {
        let origin1 = Origin::parse("http://subdomain.somedomain.io");
        let origin2 = Origin::parse("http://somedomain.io:8080");
        let host = Host::parse("http://somedomain.io");

        let origin1 = Some(&*origin1);
        let origin2 = Some(&*origin2);
        let host = Some(&*host);

        let res1 = get_cors_allow_origin(origin1, host, &Some(vec![]));
        let res2 = get_cors_allow_origin(origin2, host, &Some(vec![]));

        assert_eq!(res1, AllowCors::Invalid);
        assert_eq!(res2, AllowCors::Invalid);
    }

    #[test]
    fn should_allow_origins_that_matches_hosts() {
        let origin = Origin::parse("http://127.0.0.1:8080");
        let host = Host::parse("http://127.0.0.1:8080");

        let origin = Some(&*origin);
        let host = Some(&*host);

        let res = get_cors_allow_origin(origin, host, &None);

        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
        let origin = None;
        let host = None;

        let res = get_cors_allow_origin(origin, host, &None);

        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_domain_when_all_are_allowed() {
        let origin = Some("superstring.ch");
        let host = None;

        let res = get_cors_allow_origin(origin, host, &None);

        assert_eq!(res, AllowCors::Ok("superstring.ch".into()));
    }

    #[test]
    fn should_return_none_for_empty_origin() {
        let origin = None;
        let host = None;

        let res = get_cors_allow_origin(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
        );

        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_for_empty_list() {
        let origin = None;
        let host = None;

        let res = get_cors_allow_origin(origin, host, &Some(Vec::new()));

        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_none_for_not_matching_origin() {
        let origin = Some("http://superstring.ch".into());
        let host = None;

        let res = get_cors_allow_origin(
            origin,
            host,
            &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
        );

        assert_eq!(res, AllowCors::Invalid);
    }

    #[test]
    fn should_return_specific_origin_if_we_allow_any() {
        let origin = Some("http://superstring.ch".into());
        let host = None;

        let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));

        assert_eq!(
            res,
            AllowCors::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into()))
        );
    }

    #[test]
    fn should_return_none_if_origin_is_not_defined() {
        let origin = None;
        let host = None;

        let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_null_if_origin_is_null() {
        let origin = Some("null".into());
        let host = None;

        let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

        assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
    }

    #[test]
    fn should_return_specific_origin_if_there_is_a_match() {
        let origin = Some("http://superstring.ch".into());
        let host = None;

        let res = get_cors_allow_origin(
            origin,
            host,
            &Some(vec![
                AccessControlAllowOrigin::Value("http://ethereum.org".into()),
                AccessControlAllowOrigin::Value("http://superstring.ch".into()),
            ]),
        );

        assert_eq!(
            res,
            AllowCors::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into()))
        );
    }

    #[test]
    fn should_support_wildcards() {
        let origin1 = Some("http://superstring.ch".into());
        // Shares the `.ch` prefix but must not match the `*.ch` wildcard.
        let origin2 = Some("http://superstring.cht".into());
        let origin3 = Some("chrome-extension://test".into());
        let host = None;
        // The wildcard must target the TLD used by `origin1`, or it can never match.
        let allowed = Some(vec![
            AccessControlAllowOrigin::Value("http://*.ch".into()),
            AccessControlAllowOrigin::Value("chrome-extension://*".into()),
        ]);

        let res1 = get_cors_allow_origin(origin1, host, &allowed);
        let res2 = get_cors_allow_origin(origin2, host, &allowed);
        let res3 = get_cors_allow_origin(origin3, host, &allowed);

        assert_eq!(
            res1,
            AllowCors::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into()))
        );
        assert_eq!(res2, AllowCors::Invalid);
        assert_eq!(
            res3,
            AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))
        );
    }

    #[test]
    fn should_return_invalid_if_header_not_allowed() {
        let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
        let headers = vec!["Access-Control-Request-Headers"];
        let requested = vec!["x-not-allowed"];

        let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| x);

        assert_eq!(res, AllowCors::Invalid);
    }

    #[test]
    fn should_return_valid_if_header_allowed() {
        let allowed = vec!["x-allowed".to_owned()];
        let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
        let headers = vec!["Access-Control-Request-Headers"];
        let requested = vec!["x-allowed"];

        let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| {
            (*x).to_owned()
        });

        let allowed = vec!["x-allowed".to_owned()];
        assert_eq!(res, AllowCors::Ok(allowed));
    }

    #[test]
    fn should_return_no_allowed_headers_if_none_in_request() {
        let allowed = vec!["x-allowed".to_owned()];
        let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
        let headers: Vec<String> = vec![];

        let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);

        assert_eq!(res, AllowCors::NotRequired);
    }

    #[test]
    fn should_return_not_required_if_any_header_allowed() {
        let cors_allow_headers = AccessControlAllowHeaders::Any;
        let headers: Vec<String> = vec![];

        let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers.into(), |x| x);

        assert_eq!(res, AllowCors::NotRequired);
    }
}