susydev_jsonrpc_server_utils/
cors.rs

1//! CORS handling utility functions
2use unicase;
3
4pub use self::unicase::Ascii;
5use crate::hosts::{Host, Port};
6use crate::matcher::{Matcher, Pattern};
7use std::collections::HashSet;
8use std::{fmt, ops};
9
/// Scheme part of a request origin.
///
/// `http` and `https` have dedicated variants; any other scheme is carried
/// verbatim in `Custom`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OriginProtocol {
	/// Plain `http` scheme.
	Http,
	/// Secure `https` scheme.
	Https,
	/// Any other scheme, stored as the given string.
	Custom(String),
}
20
21/// Request Origin
22#[derive(Clone, PartialEq, Eq, Debug, Hash)]
23pub struct Origin {
24	protocol: OriginProtocol,
25	host: Host,
26	as_string: String,
27	matcher: Matcher,
28}
29
30impl<T: AsRef<str>> From<T> for Origin {
31	fn from(string: T) -> Self {
32		Origin::parse(string.as_ref())
33	}
34}
35
36impl Origin {
37	fn with_host(protocol: OriginProtocol, host: Host) -> Self {
38		let string = Self::to_string(&protocol, &host);
39		let matcher = Matcher::new(&string);
40
41		Origin {
42			protocol,
43			host,
44			as_string: string,
45			matcher,
46		}
47	}
48
49	/// Creates new origin given protocol, hostname and port parts.
50	/// Pre-processes input data if necessary.
51	pub fn new<T: Into<Port>>(protocol: OriginProtocol, host: &str, port: T) -> Self {
52		Self::with_host(protocol, Host::new(host, port))
53	}
54
55	/// Attempts to parse given string as a `Origin`.
56	/// NOTE: This method always succeeds and falls back to sensible defaults.
57	pub fn parse(data: &str) -> Self {
58		let mut it = data.split("://");
59		let proto = it.next().expect("split always returns non-empty iterator.");
60		let hostname = it.next();
61
62		let (proto, hostname) = match hostname {
63			None => (None, proto),
64			Some(hostname) => (Some(proto), hostname),
65		};
66
67		let proto = proto.map(str::to_lowercase);
68		let hostname = Host::parse(hostname);
69
70		let protocol = match proto {
71			None => OriginProtocol::Http,
72			Some(ref p) if p == "http" => OriginProtocol::Http,
73			Some(ref p) if p == "https" => OriginProtocol::Https,
74			Some(other) => OriginProtocol::Custom(other),
75		};
76
77		Origin::with_host(protocol, hostname)
78	}
79
80	fn to_string(protocol: &OriginProtocol, host: &Host) -> String {
81		format!(
82			"{}://{}",
83			match *protocol {
84				OriginProtocol::Http => "http",
85				OriginProtocol::Https => "https",
86				OriginProtocol::Custom(ref protocol) => protocol,
87			},
88			&**host,
89		)
90	}
91}
92
93impl Pattern for Origin {
94	fn matches<T: AsRef<str>>(&self, other: T) -> bool {
95		self.matcher.matches(other)
96	}
97}
98
99impl ops::Deref for Origin {
100	type Target = str;
101	fn deref(&self) -> &Self::Target {
102		&self.as_string
103	}
104}
105
106/// Origins allowed to access
107#[derive(Debug, Clone, PartialEq, Eq)]
108pub enum AccessControlAllowOrigin {
109	/// Specific hostname
110	Value(Origin),
111	/// null-origin (file:///, sandboxed iframe)
112	Null,
113	/// Any non-null origin
114	Any,
115}
116
117impl fmt::Display for AccessControlAllowOrigin {
118	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119		write!(
120			f,
121			"{}",
122			match *self {
123				AccessControlAllowOrigin::Any => "*",
124				AccessControlAllowOrigin::Null => "null",
125				AccessControlAllowOrigin::Value(ref val) => val,
126			}
127		)
128	}
129}
130
131impl<T: Into<String>> From<T> for AccessControlAllowOrigin {
132	fn from(s: T) -> AccessControlAllowOrigin {
133		match s.into().as_str() {
134			"all" | "*" | "any" => AccessControlAllowOrigin::Any,
135			"null" => AccessControlAllowOrigin::Null,
136			origin => AccessControlAllowOrigin::Value(origin.into()),
137		}
138	}
139}
140
/// Headers allowed to access
///
/// Derives `Eq` in addition to `PartialEq`, matching the other CORS types in
/// this module; the only payload is `Vec<String>`, which is itself `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessControlAllowHeaders {
	/// Only the listed header names are allowed.
	Only(Vec<String>),
	/// Any header is allowed.
	Any,
}
149
/// Outcome of a CORS check, optionally carrying the header value to send back.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AllowCors<T> {
	/// No CORS header is needed: the request carried no `Origin`.
	NotRequired,
	/// No CORS header is returned: the origin may not access the resource.
	Invalid,
	/// The origin is allowed; the payload is the header to include in the response.
	Ok(T),
}

impl<T> AllowCors<T> {
	/// Transforms the payload of the `Ok` variant with `f`.
	/// `NotRequired` and `Invalid` pass through unchanged and `f` is not called.
	pub fn map<F, O>(self, f: F) -> AllowCors<O>
	where
		F: FnOnce(T) -> O,
	{
		match self {
			AllowCors::Ok(inner) => AllowCors::Ok(f(inner)),
			AllowCors::Invalid => AllowCors::Invalid,
			AllowCors::NotRequired => AllowCors::NotRequired,
		}
	}
}

/// Collapses the result into an `Option`: `Ok(header)` becomes `Some(header)`,
/// both `NotRequired` and `Invalid` become `None`.
impl<T> Into<Option<T>> for AllowCors<T> {
	fn into(self) -> Option<T> {
		if let AllowCors::Ok(header) = self {
			Some(header)
		} else {
			None
		}
	}
}
187
188/// Returns correct CORS header (if any) given list of allowed origins and current origin.
189pub fn get_cors_allow_origin(
190	origin: Option<&str>,
191	host: Option<&str>,
192	allowed: &Option<Vec<AccessControlAllowOrigin>>,
193) -> AllowCors<AccessControlAllowOrigin> {
194	match origin {
195		None => AllowCors::NotRequired,
196		Some(ref origin) => {
197			if let Some(host) = host {
198				// Request initiated from the same server.
199				if origin.ends_with(host) {
200					// Additional check
201					let origin = Origin::parse(origin);
202					if &*origin.host == host {
203						return AllowCors::NotRequired;
204					}
205				}
206			}
207
208			match allowed.as_ref() {
209				None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null),
210				None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))),
211				Some(ref allowed) if *origin == "null" => allowed
212					.iter()
213					.find(|cors| **cors == AccessControlAllowOrigin::Null)
214					.cloned()
215					.map(AllowCors::Ok)
216					.unwrap_or(AllowCors::Invalid),
217				Some(ref allowed) => allowed
218					.iter()
219					.find(|cors| match **cors {
220						AccessControlAllowOrigin::Any => true,
221						AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true,
222						_ => false,
223					})
224					.map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin)))
225					.map(AllowCors::Ok)
226					.unwrap_or(AllowCors::Invalid),
227			}
228		}
229	}
230}
231
232/// Validates if the `AccessControlAllowedHeaders` in the request are allowed.
233pub fn get_cors_allow_headers<T: AsRef<str>, O, F: Fn(T) -> O>(
234	mut headers: impl Iterator<Item = T>,
235	requested_headers: impl Iterator<Item = T>,
236	cors_allow_headers: &AccessControlAllowHeaders,
237	to_result: F,
238) -> AllowCors<Vec<O>> {
239	// Check if the header fields which were sent in the request are allowed
240	if let AccessControlAllowHeaders::Only(only) = cors_allow_headers {
241		let are_all_allowed = headers.all(|header| {
242			let name = &Ascii::new(header.as_ref());
243			only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
244		});
245
246		if !are_all_allowed {
247			return AllowCors::Invalid;
248		}
249	}
250
251	// Check if `AccessControlRequestHeaders` contains fields which were allowed
252	let (filtered, headers) = match cors_allow_headers {
253		AccessControlAllowHeaders::Any => {
254			let headers = requested_headers.map(to_result).collect();
255			(false, headers)
256		}
257		AccessControlAllowHeaders::Only(only) => {
258			let mut filtered = false;
259			let headers: Vec<_> = requested_headers
260				.filter(|header| {
261					let name = &Ascii::new(header.as_ref());
262					filtered = true;
263					only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name)
264				})
265				.map(to_result)
266				.collect();
267
268			(filtered, headers)
269		}
270	};
271
272	if headers.is_empty() {
273		if filtered {
274			AllowCors::Invalid
275		} else {
276			AllowCors::NotRequired
277		}
278	} else {
279		AllowCors::Ok(headers)
280	}
281}
282
283lazy_static! {
284	/// Returns headers which are always allowed.
285	static ref ALWAYS_ALLOWED_HEADERS: HashSet<Ascii<&'static str>> = {
286		let mut hs = HashSet::new();
287		hs.insert(Ascii::new("Accept"));
288		hs.insert(Ascii::new("Accept-Language"));
289		hs.insert(Ascii::new("Access-Control-Allow-Origin"));
290		hs.insert(Ascii::new("Access-Control-Request-Headers"));
291		hs.insert(Ascii::new("Content-Language"));
292		hs.insert(Ascii::new("Content-Type"));
293		hs.insert(Ascii::new("Host"));
294		hs.insert(Ascii::new("Origin"));
295		hs.insert(Ascii::new("Content-Length"));
296		hs.insert(Ascii::new("Connection"));
297		hs.insert(Ascii::new("User-Agent"));
298		hs
299	};
300}
301
/// Unit tests for origin parsing and for the CORS origin / header checks.
#[cfg(test)]
mod tests {
	use std::iter;

	use super::*;
	use crate::hosts::Host;

	#[test]
	fn should_parse_origin() {
		use self::OriginProtocol::*;

		assert_eq!(Origin::parse("http://superstring.ch"), Origin::new(Http, "superstring.ch", None));
		// The input must use the `https` scheme for the parsed protocol to be `Https`.
		assert_eq!(
			Origin::parse("https://superstring.ch:8443"),
			Origin::new(Https, "superstring.ch", Some(8443))
		);
		assert_eq!(
			Origin::parse("chrome-extension://124.0.0.1"),
			Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None)
		);
		assert_eq!(
			Origin::parse("superstring.ch/somepath"),
			Origin::new(Http, "superstring.ch", None)
		);
		assert_eq!(
			Origin::parse("127.0.0.1:8545/somepath"),
			Origin::new(Http, "127.0.0.1", Some(8545))
		);
	}

	#[test]
	fn should_not_allow_partially_matching_origin() {
		// given
		let origin1 = Origin::parse("http://subdomain.somedomain.io");
		let origin2 = Origin::parse("http://somedomain.io:8080");
		let host = Host::parse("http://somedomain.io");

		let origin1 = Some(&*origin1);
		let origin2 = Some(&*origin2);
		let host = Some(&*host);

		// when
		let res1 = get_cors_allow_origin(origin1, host, &Some(vec![]));
		let res2 = get_cors_allow_origin(origin2, host, &Some(vec![]));

		// then
		assert_eq!(res1, AllowCors::Invalid);
		assert_eq!(res2, AllowCors::Invalid);
	}

	#[test]
	fn should_allow_origins_that_matches_hosts() {
		// given
		let origin = Origin::parse("http://127.0.0.1:8080");
		let host = Host::parse("http://127.0.0.1:8080");

		let origin = Some(&*origin);
		let host = Some(&*host);

		// when
		let res = get_cors_allow_origin(origin, host, &None);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_when_there_are_no_cors_domains_and_no_origin() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &None);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_domain_when_all_are_allowed() {
		// given
		let origin = Some("superstring.ch");
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &None);

		// then
		assert_eq!(res, AllowCors::Ok("superstring.ch".into()));
	}

	#[test]
	fn should_return_none_for_empty_origin() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
		);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_for_empty_list() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(Vec::new()));

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_none_for_not_matching_origin() {
		// given
		let origin = Some("http://superstring.ch".into());
		let host = None;

		// when
		let res = get_cors_allow_origin(
			origin,
			host,
			&Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]),
		);

		// then
		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_specific_origin_if_we_allow_any() {
		// given
		let origin = Some("http://superstring.ch".into());
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any]));

		// then
		assert_eq!(
			res,
			AllowCors::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into()))
		);
	}

	#[test]
	fn should_return_none_if_origin_is_not_defined() {
		// given
		let origin = None;
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_null_if_origin_is_null() {
		// given
		let origin = Some("null".into());
		let host = None;

		// when
		let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null]));

		// then
		assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null));
	}

	#[test]
	fn should_return_specific_origin_if_there_is_a_match() {
		// given
		let origin = Some("http://superstring.ch".into());
		let host = None;

		// when
		let res = get_cors_allow_origin(
			origin,
			host,
			&Some(vec![
				AccessControlAllowOrigin::Value("http://ethereum.org".into()),
				AccessControlAllowOrigin::Value("http://superstring.ch".into()),
			]),
		);

		// then
		assert_eq!(
			res,
			AllowCors::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into()))
		);
	}

	#[test]
	fn should_support_wildcards() {
		// given
		let origin1 = Some("http://superstring.ch".into());
		let origin2 = Some("http://superstring.cht".into());
		let origin3 = Some("chrome-extension://test".into());
		let host = None;
		// The wildcard uses the `.ch` suffix so that `origin1` matches it
		// while `origin2` (`.cht`) does not.
		let allowed = Some(vec![
			AccessControlAllowOrigin::Value("http://*.ch".into()),
			AccessControlAllowOrigin::Value("chrome-extension://*".into()),
		]);

		// when
		let res1 = get_cors_allow_origin(origin1, host, &allowed);
		let res2 = get_cors_allow_origin(origin2, host, &allowed);
		let res3 = get_cors_allow_origin(origin3, host, &allowed);

		// then
		assert_eq!(
			res1,
			AllowCors::Ok(AccessControlAllowOrigin::Value("http://superstring.ch".into()))
		);
		assert_eq!(res2, AllowCors::Invalid);
		assert_eq!(
			res3,
			AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))
		);
	}

	#[test]
	fn should_return_invalid_if_header_not_allowed() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]);
		let headers = vec!["Access-Control-Request-Headers"];
		let requested = vec!["x-not-allowed"];

		// when
		let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| x);

		// then
		assert_eq!(res, AllowCors::Invalid);
	}

	#[test]
	fn should_return_valid_if_header_allowed() {
		// given
		let allowed = vec!["x-allowed".to_owned()];
		let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
		let headers = vec!["Access-Control-Request-Headers"];
		let requested = vec!["x-allowed"];

		// when
		let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| {
			(*x).to_owned()
		});

		// then
		let allowed = vec!["x-allowed".to_owned()];
		assert_eq!(res, AllowCors::Ok(allowed));
	}

	#[test]
	fn should_return_no_allowed_headers_if_none_in_request() {
		// given
		let allowed = vec!["x-allowed".to_owned()];
		let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone());
		let headers: Vec<String> = vec![];

		// when
		let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

	#[test]
	fn should_return_not_required_if_any_header_allowed() {
		// given
		let cors_allow_headers = AccessControlAllowHeaders::Any;
		let headers: Vec<String> = vec![];

		// when
		let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers.into(), |x| x);

		// then
		assert_eq!(res, AllowCors::NotRequired);
	}

}