use futures_util::io::{AsyncRead, AsyncWrite};
use headers::{
AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt,
Origin,
};
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
Method, StatusCode,
};
use std::collections::HashSet;
use std::convert::TryFrom;
use crate::{
server::{
glitch::{Glitch, Result},
ResponseWriter,
},
Request,
};
/// Builder for a [`Cors`] policy.
///
/// Created by [`Cors::build`] and turned into a [`Cors`] with [`CorsBuilder::finish`].
pub struct CorsBuilder {
    /// Whether `Access-Control-Allow-Credentials: true` is added to CORS responses.
    pub credentials: bool,
    /// Header names a preflight may list in `Access-Control-Request-Headers`;
    /// sent back as `Access-Control-Allow-Headers` on preflight responses.
    pub allowed_headers: HashSet<HeaderName>,
    /// Header names advertised via `Access-Control-Expose-Headers` (omitted when empty).
    pub exposed_headers: HashSet<HeaderName>,
    /// `Access-Control-Max-Age` value (seconds) for preflight responses; `None` omits it.
    pub max_age: Option<u64>,
    /// Methods a preflight may request; sent back as `Access-Control-Allow-Methods`.
    pub methods: HashSet<http::Method>,
    /// Allowed `Origin` header values; `None` means any origin is allowed.
    pub origins: Option<HashSet<HeaderValue>>,
}
impl CorsBuilder {
pub fn allow_credentials(mut self, allow: bool) -> Self {
self.credentials = allow;
self
}
pub fn allow_method<M>(mut self, method: M) -> Self
where
http::Method: TryFrom<M>,
{
let method = match TryFrom::try_from(method) {
Ok(m) => m,
_ => panic!("illegal Method"),
};
self.methods.insert(method);
self
}
pub fn allow_methods<I>(mut self, methods: I) -> Self
where
I: IntoIterator,
http::Method: TryFrom<I::Item>,
{
let iter = methods.into_iter().map(|m| match TryFrom::try_from(m) {
Ok(m) => m,
_ => panic!("illegal Method"),
});
self.methods.extend(iter);
self
}
pub fn allow_header<H>(mut self, header: H) -> Self
where
HeaderName: TryFrom<H>,
{
let header = match TryFrom::try_from(header) {
Ok(m) => m,
_ => panic!("illegal Header"),
};
self.allowed_headers.insert(header);
self
}
pub fn allow_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
HeaderName: TryFrom<I::Item>,
{
let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
Ok(h) => h,
_ => panic!("illegal Header"),
});
self.allowed_headers.extend(iter);
self
}
pub fn expose_header<H>(mut self, header: H) -> Self
where
HeaderName: TryFrom<H>,
{
let header = match TryFrom::try_from(header) {
Ok(m) => m,
_ => panic!("illegal Header"),
};
self.exposed_headers.insert(header);
self
}
pub fn expose_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator,
HeaderName: TryFrom<I::Item>,
{
let iter = headers.into_iter().map(|h| match TryFrom::try_from(h) {
Ok(h) => h,
_ => panic!("illegal Header"),
});
self.exposed_headers.extend(iter);
self
}
pub fn allow_any_origin(mut self) -> Self {
self.origins = None;
self
}
pub fn allow_origin(self, origin: impl IntoOrigin) -> Self {
self.allow_origins(Some(origin))
}
pub fn allow_origins<I>(mut self, origins: I) -> Self
where
I: IntoIterator,
I::Item: IntoOrigin,
{
let iter = origins
.into_iter()
.map(IntoOrigin::into_origin)
.map(|origin| {
origin
.to_string()
.parse()
.expect("Origin is always a valid HeaderValue")
});
self.origins.get_or_insert_with(HashSet::new).extend(iter);
self
}
pub fn max_age(mut self, seconds: u64) -> Self {
self.max_age = Some(seconds);
self
}
pub fn finish(self) -> Cors {
let exposed_headers = if self.exposed_headers.is_empty() {
None
} else {
Some(self.exposed_headers.into_iter().collect())
};
Cors {
credentials: self.credentials,
allowed_headers: self.allowed_headers.iter().cloned().collect(),
allowed_headers_set: self.allowed_headers,
exposed_headers,
max_age: self.max_age,
methods: self.methods.iter().cloned().collect(),
methods_set: self.methods,
origins: self.origins,
}
}
}
/// A finished CORS policy produced by [`CorsBuilder::finish`].
///
/// Header and method lists are stored twice: as `HashSet`s used to validate
/// incoming requests, and as prebuilt typed headers inserted into responses.
#[derive(Clone)]
pub struct Cors {
    /// Whether `Access-Control-Allow-Credentials: true` is added to CORS responses.
    credentials: bool,
    /// Lookup set for names requested via `Access-Control-Request-Headers`.
    allowed_headers_set: HashSet<HeaderName>,
    /// Same names as `allowed_headers_set`, prebuilt as `Access-Control-Allow-Headers`.
    allowed_headers: AccessControlAllowHeaders,
    /// Prebuilt `Access-Control-Expose-Headers`; `None` when nothing is exposed.
    exposed_headers: Option<AccessControlExposeHeaders>,
    /// `Access-Control-Max-Age` value (seconds) for preflight responses, if any.
    max_age: Option<u64>,
    /// Lookup set for the method named in `Access-Control-Request-Method`.
    methods_set: HashSet<http::Method>,
    /// Same methods as `methods_set`, prebuilt as `Access-Control-Allow-Methods`.
    methods: AccessControlAllowMethods,
    /// Allowed `Origin` values; `None` accepts any origin.
    origins: Option<HashSet<HeaderValue>>,
}
impl Cors {
pub fn build() -> CorsBuilder {
CorsBuilder {
credentials: false,
allowed_headers: HashSet::new(),
exposed_headers: HashSet::new(),
max_age: None,
methods: HashSet::new(),
origins: None,
}
}
pub fn validate<W>(&self, req: &Request, resp_wtr: &mut ResponseWriter<W>) -> Result<()>
where
W: AsyncRead + AsyncWrite + Clone + Send + Sync + Unpin + 'static,
{
let req_method = req.method();
let req_origin = req.headers().get(header::ORIGIN);
match (req_method, req_origin) {
(&Method::OPTIONS, Some(origin)) => {
if !self.is_origin_allowed(origin) {
return Err(Glitch::bad_request());
}
let headers = req.headers();
if let Some(req_method) = headers.get(header::ACCESS_CONTROL_REQUEST_METHOD) {
if !self.is_method_allowed(req_method) {
return Err(Glitch::bad_request());
}
} else {
println!("hit");
return Err(Glitch::bad_request());
}
if let Some(req_headers) = headers.get(header::ACCESS_CONTROL_REQUEST_HEADERS) {
let headers = match req_headers.to_str() {
Ok(h) => h,
Err(_) => return Err(Glitch::bad_request()),
};
for header in headers.split(',') {
if !self.is_header_allowed(header) {
return Err(Glitch::bad_request());
}
}
}
let mut resp = Glitch::new();
let mut headers = HeaderMap::new();
self.append_preflight_headers(&mut headers);
headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
resp.status = Some(StatusCode::OK);
resp.headers = Some(headers);
Err(resp)
}
(_, Some(origin)) => {
if self.is_origin_allowed(origin) {
let mut headers = resp_wtr.response_mut().headers_mut();
self.append_common_headers(&mut headers);
resp_wtr.insert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone());
return Ok(());
}
Err(Glitch::bad_request())
}
(_, _) => {
Ok(())
}
}
}
fn is_method_allowed(&self, header: &HeaderValue) -> bool {
http::Method::from_bytes(header.as_bytes())
.map(|method| self.methods_set.contains(&method))
.unwrap_or(false)
}
fn is_header_allowed(&self, header: &str) -> bool {
HeaderName::from_bytes(header.as_bytes())
.map(|header| self.allowed_headers_set.contains(&header))
.unwrap_or(false)
}
fn is_origin_allowed(&self, origin: &HeaderValue) -> bool {
if let Some(ref allowed) = self.origins {
allowed.contains(origin)
} else {
true
}
}
fn append_preflight_headers(&self, headers: &mut HeaderMap) {
self.append_common_headers(headers);
headers.typed_insert(self.allowed_headers.clone());
headers.typed_insert(self.methods.clone());
if let Some(max_age) = self.max_age {
headers.insert(header::ACCESS_CONTROL_MAX_AGE, max_age.into());
}
}
fn append_common_headers(&self, headers: &mut HeaderMap) {
if self.credentials {
headers.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
}
if let Some(expose_headers_header) = &self.exposed_headers {
headers.typed_insert(expose_headers_header.clone())
}
}
}
/// Conversion into an [`Origin`], accepted by [`CorsBuilder::allow_origin`] and
/// [`CorsBuilder::allow_origins`].
pub trait IntoOrigin {
    /// Converts `self` into an [`Origin`].
    fn into_origin(self) -> Origin;
}
impl<'a> IntoOrigin for &'a str {
    /// Parses a `"scheme://authority"` string (e.g. `"https://example.com"`) into an
    /// [`Origin`].
    ///
    /// # Panics
    ///
    /// Panics with "missing scheme" if the string has no `://` separator, and with
    /// "invalid Origin" if [`Origin::try_from_parts`] rejects the pieces.
    fn into_origin(self) -> Origin {
        let mut pieces = self.splitn(2, "://");
        match (pieces.next(), pieces.next()) {
            (Some(scheme), Some(authority)) => {
                Origin::try_from_parts(scheme, authority, None).expect("invalid Origin")
            }
            _ => panic!("missing scheme"),
        }
    }
}