tower_web/middleware/cors/
service.rs1use super::{Config, CorsResource};
2
3use futures::{Async, Future, Poll};
4use http::{self, HeaderMap, Request, Response, StatusCode};
5use tower_service::Service;
6use util::http::HttpService;
7
8use std::sync::Arc;
9
10#[derive(Debug)]
12pub struct CorsService<S> {
13 inner: S,
14 config: Arc<Config>,
15}
16
17impl<S> CorsService<S> {
18 pub(super) fn new(inner: S, config: Arc<Config>) -> CorsService<S> {
19 CorsService { inner, config }
20 }
21}
22
23impl<S> Service for CorsService<S>
24where
25 S: HttpService,
26{
27 type Request = Request<S::RequestBody>;
28 type Response = Response<Option<S::ResponseBody>>;
29 type Error = S::Error;
30 type Future = CorsFuture<S::Future>;
31
32 fn poll_ready(&mut self) -> Poll<(), Self::Error> {
33 self.inner.poll_http_ready()
34 }
35
36 fn call(&mut self, request: Self::Request) -> Self::Future {
37 let inner = match self.config.process_request(&request) {
38 Ok(CorsResource::Preflight(headers)) => CorsFutureInner::Handled(Some(headers)),
39 Ok(CorsResource::Simple(headers)) => {
40 CorsFutureInner::Simple(self.inner.call_http(request), Some(headers))
41 }
42 Err(e) => {
43 debug!("CORS request to {} is denied: {:?}", request.uri(), e);
44 CorsFutureInner::Handled(None)
45 }
46 };
47
48 CorsFuture(inner)
49 }
50}
51
52#[derive(Debug)]
53pub struct CorsFuture<F>(CorsFutureInner<F>);
54
55impl<F, ResponseBody> Future for CorsFuture<F>
56where
57 F: Future<Item = http::Response<ResponseBody>>,
58{
59 type Item = http::Response<Option<ResponseBody>>;
60 type Error = F::Error;
61
62 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
63 self.0.poll()
64 }
65}
66
67#[derive(Debug)]
68enum CorsFutureInner<F> {
69 Simple(F, Option<HeaderMap>),
70 Handled(Option<HeaderMap>),
71}
72
73impl<F, ResponseBody> Future for CorsFutureInner<F>
74where
75 F: Future<Item = http::Response<ResponseBody>>,
76{
77 type Item = http::Response<Option<ResponseBody>>;
78 type Error = F::Error;
79
80 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
81 use self::CorsFutureInner::*;
82
83 match self {
84 Simple(f, headers) => {
85 let mut response = try_ready!(f.poll());
86 let headers = headers.take().expect("poll called twice");
87 response.headers_mut().extend(headers);
88 Ok(Async::Ready(response.map(Some)))
89 }
90 Handled(headers) => {
91 let mut response = http::Response::new(None);
92 *response.status_mut() = StatusCode::FORBIDDEN;
93
94 if let Some(headers) = headers.take() {
95 *response.status_mut() = StatusCode::NO_CONTENT;
96 *response.headers_mut() = headers;
97 }
98
99 Ok(Async::Ready(response))
100 }
101 }
102 }
103}
104
#[cfg(test)]
mod test {
    use futures::future::{self, FutureResult};
    use http::{
        header::{self, HeaderValue},
        Method,
    };

    use middleware::cors::{AllowedOrigins, CorsBuilder};
    use util::buf_stream::{self, Empty};

    use super::*;

    // Imported by name so the trait object can be written with `dyn`:
    // in edition 2015, `dyn ::std::error::Error` would parse as a path.
    use std::error::Error as StdError;

    // Explicit `dyn`; the bare `Box<Error>` trait-object form is deprecated.
    type TestError = Box<dyn StdError>;
    type TestResult = ::std::result::Result<(), TestError>;

    /// Body type used by the mock service; its contents are never inspected.
    type DontCare = Empty<Option<[u8; 1]>, ()>;

    /// Stand-in for the wrapped service that records how it was used.
    #[derive(Debug, Default)]
    struct MockService {
        /// How many times `poll_ready` was invoked.
        poll_ready_count: usize,
        /// Every request that reached this service, in call order.
        requests: Vec<http::Request<DontCare>>,
    }

    impl Service for MockService {
        type Request = http::Request<DontCare>;
        type Response = http::Response<DontCare>;
        type Error = TestError;
        type Future = FutureResult<Self::Response, Self::Error>;

        fn poll_ready(&mut self) -> Poll<(), Self::Error> {
            self.poll_ready_count += 1;
            Ok(Async::Ready(()))
        }

        /// Records the request and resolves immediately with a default
        /// `http::Response` wrapping an empty body.
        fn call(&mut self, request: Self::Request) -> Self::Future {
            self.requests.push(request);
            future::ok(http::Response::new(buf_stream::empty()))
        }
    }

    #[test]
    fn polls_the_inner_service() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);

        service.poll_ready()?;
        assert_eq!(service.inner.poll_ready_count, 1);

        Ok(())
    }

    #[test]
    fn forwards_the_request_when_not_cors() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);
        let req = http::Request::builder().body(buf_stream::empty())?;

        service.call(req);
        assert_eq!(service.inner.requests.len(), 1);

        Ok(())
    }

    #[test]
    fn passes_through_the_inner_response_when_not_cors() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);
        let req = http::Request::builder().body(buf_stream::empty())?;

        // The forwarded path keeps the inner status and wraps its body in `Some`.
        let resp = service.call(req).wait()?;
        assert_eq!(resp.status(), StatusCode::OK);
        assert!(resp.body().is_some());

        Ok(())
    }

    #[test]
    fn does_not_forward_the_request_when_preflight() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);
        let req = http::Request::builder()
            .method(Method::OPTIONS)
            .header(
                header::ORIGIN,
                HeaderValue::from_static("http://test.example"),
            ).header(
                header::ACCESS_CONTROL_REQUEST_METHOD,
                HeaderValue::from_static("POST"),
            ).body(buf_stream::empty())?;

        service.call(req);
        assert_eq!(service.inner.requests.len(), 0);

        Ok(())
    }

    #[test]
    fn responds_with_error_when_bad_cors() -> TestResult {
        let cfg = Arc::new(CorsBuilder::new().into_config());
        let mut service = CorsService::new(MockService::default(), cfg);
        let req = http::Request::builder()
            .header(
                header::ORIGIN,
                HeaderValue::from_static("http://not-me.example"),
            ).body(buf_stream::empty())?;

        let resp = service.call(req).wait()?;
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
        // A rejected request is answered by the middleware alone.
        assert!(resp.body().is_none());
        assert_eq!(service.inner.requests.len(), 0);

        Ok(())
    }

    #[test]
    fn responds_with_no_content_when_ok_preflight() -> TestResult {
        let cfg = CorsBuilder::new()
            .allow_origins(AllowedOrigins::Any { allow_null: false })
            .allow_methods(vec![Method::POST])
            .into_config();

        let mut service = CorsService::new(MockService::default(), Arc::new(cfg));
        let req = http::Request::builder()
            .method(Method::OPTIONS)
            .header(
                header::ACCESS_CONTROL_REQUEST_METHOD,
                HeaderValue::from_static("POST"),
            ).header(
                header::ORIGIN,
                HeaderValue::from_bytes(b"http://test.example")?,
            ).body(buf_stream::empty())?;

        let resp = service.call(req).wait()?;
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        // An accepted preflight is also answered without the inner service.
        assert!(resp.body().is_none());
        assert_eq!(service.inner.requests.len(), 0);

        Ok(())
    }
}