tower_web/middleware/cors/
service.rs

1use super::{Config, CorsResource};
2
3use futures::{Async, Future, Poll};
4use http::{self, HeaderMap, Request, Response, StatusCode};
5use tower_service::Service;
6use util::http::HttpService;
7
8use std::sync::Arc;
9
10/// Decorates a service, providing an implementation of the CORS specification.
11#[derive(Debug)]
12pub struct CorsService<S> {
13    inner: S,
14    config: Arc<Config>,
15}
16
17impl<S> CorsService<S> {
18    pub(super) fn new(inner: S, config: Arc<Config>) -> CorsService<S> {
19        CorsService { inner, config }
20    }
21}
22
23impl<S> Service for CorsService<S>
24where
25    S: HttpService,
26{
27    type Request = Request<S::RequestBody>;
28    type Response = Response<Option<S::ResponseBody>>;
29    type Error = S::Error;
30    type Future = CorsFuture<S::Future>;
31
32    fn poll_ready(&mut self) -> Poll<(), Self::Error> {
33        self.inner.poll_http_ready()
34    }
35
36    fn call(&mut self, request: Self::Request) -> Self::Future {
37        let inner = match self.config.process_request(&request) {
38            Ok(CorsResource::Preflight(headers)) => CorsFutureInner::Handled(Some(headers)),
39            Ok(CorsResource::Simple(headers)) => {
40                CorsFutureInner::Simple(self.inner.call_http(request), Some(headers))
41            }
42            Err(e) => {
43                debug!("CORS request to {} is denied: {:?}", request.uri(), e);
44                CorsFutureInner::Handled(None)
45            }
46        };
47
48        CorsFuture(inner)
49    }
50}
51
52#[derive(Debug)]
53pub struct CorsFuture<F>(CorsFutureInner<F>);
54
55impl<F, ResponseBody> Future for CorsFuture<F>
56where
57    F: Future<Item = http::Response<ResponseBody>>,
58{
59    type Item = http::Response<Option<ResponseBody>>;
60    type Error = F::Error;
61
62    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
63        self.0.poll()
64    }
65}
66
67#[derive(Debug)]
68enum CorsFutureInner<F> {
69    Simple(F, Option<HeaderMap>),
70    Handled(Option<HeaderMap>),
71}
72
73impl<F, ResponseBody> Future for CorsFutureInner<F>
74where
75    F: Future<Item = http::Response<ResponseBody>>,
76{
77    type Item = http::Response<Option<ResponseBody>>;
78    type Error = F::Error;
79
80    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
81        use self::CorsFutureInner::*;
82
83        match self {
84            Simple(f, headers) => {
85                let mut response = try_ready!(f.poll());
86                let headers = headers.take().expect("poll called twice");
87                response.headers_mut().extend(headers);
88                Ok(Async::Ready(response.map(Some)))
89            }
90            Handled(headers) => {
91                let mut response = http::Response::new(None);
92                *response.status_mut() = StatusCode::FORBIDDEN;
93
94                if let Some(headers) = headers.take() {
95                    *response.status_mut() = StatusCode::NO_CONTENT;
96                    *response.headers_mut() = headers;
97                }
98
99                Ok(Async::Ready(response))
100            }
101        }
102    }
103}
104
#[cfg(test)]
mod test {
    use futures::future::{self, FutureResult};
    use http::{
        header::{self, HeaderValue},
        Method,
    };
    // Aliased so the trait object below can be written `dyn StdError`.
    // `dyn ::std::error::Error` is ambiguous in the 2015 edition.
    use std::error::Error as StdError;

    use middleware::cors::{AllowedOrigins, CorsBuilder};
    use util::buf_stream::{self, Empty};

    use super::*;

    // Explicit `dyn`: bare trait objects such as `Box<Error>` are deprecated.
    type TestError = Box<dyn StdError>;
    type TestResult = ::std::result::Result<(), TestError>;

    /// Request/response body type whose contents the tests never inspect.
    type DontCare = Empty<Option<[u8; 1]>, ()>;

    /// Stand-in for the wrapped service that records how it is used.
    #[derive(Debug, Default)]
    struct MockService {
        /// Number of times `poll_ready` has been called.
        poll_ready_count: usize,
        /// Every request passed to `call`, in arrival order.
        requests: Vec<http::Request<DontCare>>,
    }

    impl Service for MockService {
        type Request = http::Request<DontCare>;
        type Response = http::Response<DontCare>;
        type Error = TestError;
        type Future = FutureResult<Self::Response, Self::Error>;

        fn poll_ready(&mut self) -> Poll<(), Self::Error> {
            self.poll_ready_count += 1;
            Ok(Async::Ready(()))
        }

        fn call(&mut self, request: Self::Request) -> Self::Future {
            self.requests.push(request);
            future::ok(http::Response::new(buf_stream::empty()))
        }
    }

    /// `CorsService::poll_ready` must delegate to the inner service.
    #[test]
    fn polls_the_inner_service() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);

        service.poll_ready()?;
        assert_eq!(service.inner.poll_ready_count, 1);

        Ok(())
    }

    /// A request without CORS headers is passed straight through.
    #[test]
    fn forwards_the_request_when_not_cors() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);
        let req = http::Request::builder().body(buf_stream::empty())?;

        service.call(req);
        assert_eq!(service.inner.requests.len(), 1);

        Ok(())
    }

    /// An accepted preflight is answered by the middleware, not the inner service.
    #[test]
    fn does_not_forward_the_request_when_preflight() -> TestResult {
        // Use a configuration that accepts this preflight, the same one used by
        // `responds_with_no_content_when_ok_preflight`. With the default config
        // the request could be rejected instead, and the preflight path would
        // never be exercised.
        let cfg = CorsBuilder::new()
            .allow_origins(AllowedOrigins::Any { allow_null: false })
            .allow_methods(vec![Method::POST])
            .into_config();
        let mut service = CorsService::new(MockService::default(), Arc::new(cfg));
        let req = http::Request::builder()
            .method(Method::OPTIONS)
            .header(
                header::ORIGIN,
                HeaderValue::from_static("http://test.example"),
            ).header(
                header::ACCESS_CONTROL_REQUEST_METHOD,
                HeaderValue::from_static("POST"),
            ).body(buf_stream::empty())?;

        let resp = service.call(req).wait()?;
        // 204 is only produced for a preflight, so this pins the code path.
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(service.inner.requests.len(), 0);

        Ok(())
    }

    /// A request from a disallowed origin is rejected with 403.
    #[test]
    fn responds_with_error_when_bad_cors() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);
        // Disallowed "Origin" header
        let req = http::Request::builder()
            .header(
                header::ORIGIN,
                HeaderValue::from_static("http://not-me.example"),
            ).body(buf_stream::empty())?;

        let resp = service.call(req).wait()?;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);

        Ok(())
    }

    /// An accepted preflight is answered with 204 No Content.
    #[test]
    fn responds_with_no_content_when_ok_preflight() -> TestResult {
        let cfg = CorsBuilder::new()
            .allow_origins(AllowedOrigins::Any { allow_null: false })
            .allow_methods(vec![Method::POST])
            .into_config();

        let mut service = CorsService::new(MockService::default(), Arc::new(cfg));
        let req = http::Request::builder()
            .method(Method::OPTIONS)
            .header(
                header::ACCESS_CONTROL_REQUEST_METHOD,
                HeaderValue::from_static("POST"),
            ).header(
                header::ORIGIN,
                HeaderValue::from_bytes(b"http://test.example")?,
            ).body(buf_stream::empty())?;

        let resp = service.call(req).wait()?;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);

        Ok(())
    }
}