1use std::sync::Arc;
133
134use futures::future::{self, Either, Ready};
135use http::StatusCode;
136use policy::Policy;
137use tower::{Layer, Service};
138
139pub use authorizer::*;
140pub use policy::PolicyBuilder;
141pub use reporter::*;
142
143mod authorizer;
144pub mod header;
145mod policy;
146mod reporter;
147
148pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
150 enforce: bool,
151 policy: Policy,
152 authorizer: Arc<A>,
153 reporter: Arc<R>,
154}
155
156impl<A, R> Clone for SecFetchLayer<A, R> {
157 fn clone(&self) -> Self {
158 Self {
159 enforce: self.enforce,
160 policy: self.policy,
161 authorizer: self.authorizer.clone(),
162 reporter: self.reporter.clone(),
163 }
164 }
165}
166
167impl Default for SecFetchLayer {
168 fn default() -> Self {
169 Self {
170 enforce: true,
171 policy: Policy::default(),
172 authorizer: Arc::new(NoopAuthorizer),
173 reporter: Arc::new(NoopReporter),
174 }
175 }
176}
177
178impl SecFetchLayer {
179 pub fn new<F>(make_policy: F) -> Self
180 where
181 F: FnOnce(&mut PolicyBuilder),
182 {
183 let mut builder = PolicyBuilder::new();
184 make_policy(&mut builder);
185 let policy = builder.build();
186 Self {
187 policy,
188 ..Default::default()
189 }
190 }
191}
192
193impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
194 pub fn allowing(
195 self,
196 paths: impl Into<Arc<[&'static str]>>,
197 ) -> SecFetchLayer<PathAuthorizer, OldR> {
198 self.with_authorizer(PathAuthorizer::new(paths))
199 }
200
201 pub fn no_enforce(mut self) -> Self {
202 self.enforce = false;
203 self
204 }
205
206 pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
207 SecFetchLayer {
208 enforce: self.enforce,
209 policy: self.policy,
210 authorizer: Arc::from(authorizer),
211 reporter: self.reporter,
212 }
213 }
214
215 pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
216 SecFetchLayer {
217 enforce: self.enforce,
218 policy: self.policy,
219 authorizer: self.authorizer,
220 reporter: Arc::from(reporter),
221 }
222 }
223}
224
225impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
226 type Service = SecFetch<A, R, S>;
227
228 fn layer(&self, inner: S) -> Self::Service {
229 SecFetch {
230 enforce: self.enforce,
231 policy: self.policy,
232 authorizer: self.authorizer.clone(),
233 reporter: self.reporter.clone(),
234 inner,
235 }
236 }
237}
238
239pub struct SecFetch<A, R, S> {
241 enforce: bool,
242 policy: Policy,
243 authorizer: Arc<A>,
244 reporter: Arc<R>,
245 inner: S,
246}
247
248impl<A, R, S> Clone for SecFetch<A, R, S>
249where
250 S: Clone,
251{
252 fn clone(&self) -> Self {
253 Self {
254 enforce: self.enforce,
255 policy: self.policy,
256 authorizer: self.authorizer.clone(),
257 reporter: self.reporter.clone(),
258 inner: self.inner.clone(),
259 }
260 }
261}
262
263impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
264where
265 A: SecFetchAuthorizer,
266 R: SecFetchReporter,
267 S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
268 ResB: Default,
269{
270 type Response = S::Response;
271
272 type Error = S::Error;
273
274 type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;
275
276 #[inline]
277 fn poll_ready(
278 &mut self,
279 cx: &mut std::task::Context<'_>,
280 ) -> std::task::Poll<Result<(), Self::Error>> {
281 self.inner.poll_ready(cx)
282 }
283
284 fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
285 let mut allow = |request| Either::Left(self.inner.call(request));
286 let deny = || {
287 Either::Right(future::ready(Ok(http::Response::builder()
288 .status(StatusCode::FORBIDDEN)
289 .body(ResB::default())
290 .expect("valid response"))))
291 };
292
293 match self.authorizer.authorize(&request) {
294 AuthorizationDecision::Allowed => return allow(request),
295 AuthorizationDecision::Denied => return deny(),
296 AuthorizationDecision::Continue => {}
297 }
298
299 if self.policy.allow(&request) {
300 return allow(request);
301 }
302
303 self.reporter.on_request_denied(&request);
304
305 if !self.enforce {
308 return allow(request);
309 }
310
311 deny()
312 }
313}
314
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::check;
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    // Builds a `()`-bodied request carrying the three `Sec-Fetch-*` headers.
    // The method defaults to `GET` and the path to `/` when omitted.
    macro_rules! request {
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, "/", site => $site, mode => $mode, dest => $dest)
        };

        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };

        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            ::http::Request::builder()
                .method($method)
                .uri(format!("https://example.com{}", $path))
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        };
    }

    // Sends `$req` through `$layer` (the default layer when omitted) wrapped
    // around a mock inner service, then hands the response to `$assert_resp`.
    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };

        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (service, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);

            // Answer the request if it reaches the inner service. Rejected
            // requests never do, so the handle may yield `None` once the
            // service is dropped; that is expected and must not panic.
            tokio::spawn(async move {
                if let Some((_, send)) = handler.next_request().await {
                    send.send_response(http::Response::new(()));
                }
            });

            let response = service.into_inner().oneshot($req).await.unwrap();

            ($assert_resp)(response);
        };
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        let request = http::Request::new(());

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });
        let request = http::Request::new(());

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_disallows_cross_origin_requests() {
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_safe_methods_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let request =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_ignores_explicitly_authorized_requests() {
        let layer = SecFetchLayer::default().allowing(["/allowed"]);
        let request = request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let layer = SecFetchLayer::default().no_enforce();
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    // Records whether `on_request_denied` was invoked at least once.
    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_denied_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }

    #[tokio::test]
    async fn it_reports_denied_requests_even_if_enforcement_is_turned_off() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default()
            .no_enforce()
            .with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called for a rejected request in report-only mode"
        );
    }

    #[tokio::test]
    async fn it_does_not_report_allowed_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            !called,
            "reporter was called even though the request was allowed"
        );
    }
}