1use std::sync::Arc;
133
134use futures::future::{self, Either, Ready};
135use http::StatusCode;
136use policy::Policy;
137use tower::{Layer, Service};
138
139pub use authorizer::*;
140pub use policy::PolicyBuilder;
141pub use reporter::*;
142
143mod authorizer;
144pub mod header;
145mod policy;
146mod reporter;
147
148pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
150 enforce: bool,
151 policy: Policy,
152 authorizer: Arc<A>,
153 reporter: Arc<R>,
154}
155
156impl<A, R> Clone for SecFetchLayer<A, R> {
157 fn clone(&self) -> Self {
158 Self {
159 enforce: self.enforce,
160 policy: self.policy,
161 authorizer: self.authorizer.clone(),
162 reporter: self.reporter.clone(),
163 }
164 }
165}
166
167impl Default for SecFetchLayer {
168 fn default() -> Self {
169 Self {
170 enforce: true,
171 policy: Policy::default(),
172 authorizer: Arc::new(NoopAuthorizer),
173 reporter: Arc::new(NoopReporter),
174 }
175 }
176}
177
178impl SecFetchLayer {
179 pub fn new<F>(make_policy: F) -> Self
180 where
181 F: FnOnce(&mut PolicyBuilder),
182 {
183 let mut builder = PolicyBuilder::new();
184 make_policy(&mut builder);
185 let policy = builder.build();
186 Self {
187 policy,
188 ..Default::default()
189 }
190 }
191}
192
193impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
194 pub fn allowing(
195 self,
196 paths: impl Into<Arc<[&'static str]>>,
197 ) -> SecFetchLayer<PathAuthorizer, OldR> {
198 self.with_authorizer(PathAuthorizer::new(paths))
199 }
200
201 pub fn no_enforce(mut self) -> Self {
202 self.enforce = false;
203 self
204 }
205
206 pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
207 SecFetchLayer {
208 enforce: self.enforce,
209 policy: self.policy,
210 authorizer: Arc::from(authorizer),
211 reporter: self.reporter,
212 }
213 }
214
215 pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
216 SecFetchLayer {
217 enforce: self.enforce,
218 policy: self.policy,
219 authorizer: self.authorizer,
220 reporter: Arc::from(reporter),
221 }
222 }
223}
224
225impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
226 type Service = SecFetch<A, R, S>;
227
228 fn layer(&self, inner: S) -> Self::Service {
229 SecFetch {
230 enforce: self.enforce,
231 policy: self.policy,
232 authorizer: self.authorizer.clone(),
233 reporter: self.reporter.clone(),
234 inner,
235 }
236 }
237}
238
239pub struct SecFetch<A, R, S> {
241 enforce: bool,
242 policy: Policy,
243 authorizer: Arc<A>,
244 reporter: Arc<R>,
245 inner: S,
246}
247
248impl<A, R, S> Clone for SecFetch<A, R, S>
249where
250 S: Clone,
251{
252 fn clone(&self) -> Self {
253 Self {
254 enforce: self.enforce,
255 policy: self.policy,
256 authorizer: self.authorizer.clone(),
257 reporter: self.reporter.clone(),
258 inner: self.inner.clone(),
259 }
260 }
261}
262
263impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
264where
265 A: SecFetchAuthorizer,
266 R: SecFetchReporter,
267 S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
268 ResB: Default,
269{
270 type Response = S::Response;
271
272 type Error = S::Error;
273
274 type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;
275
276 #[inline]
277 fn poll_ready(
278 &mut self,
279 cx: &mut std::task::Context<'_>,
280 ) -> std::task::Poll<Result<(), Self::Error>> {
281 self.inner.poll_ready(cx)
282 }
283
284 fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
285 #[cfg(feature = "tracing")]
286 tracing::debug!(
287 method = %request.method(),
288 path = request.uri().path(),
289 "processing request",
290 );
291
292 let mut allow = |request: http::Request<ReqB>| {
293 #[cfg(feature = "tracing")]
294 tracing::debug!(
295 method = %request.method(),
296 path = request.uri().path(),
297 "request allowed",
298 );
299
300 Either::Left(self.inner.call(request))
301 };
302
303 let deny = || {
304 #[cfg(feature = "tracing")]
305 tracing::debug!(
306 method = %request.method(),
307 path = request.uri().path(),
308 "request denied",
309 );
310
311 Either::Right(future::ready(Ok(http::Response::builder()
312 .status(StatusCode::FORBIDDEN)
313 .body(ResB::default())
314 .expect("valid response"))))
315 };
316
317 match self.authorizer.authorize(&request) {
318 AuthorizationDecision::Allowed => return allow(request),
319 AuthorizationDecision::Denied => return deny(),
320 AuthorizationDecision::Continue => {}
321 }
322
323 if self.policy.allow(&request) {
324 return allow(request);
325 }
326
327 self.reporter.on_request_denied(&request);
328
329 if !self.enforce {
332 return allow(request);
333 }
334
335 deny()
336 }
337}
338
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::check;
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    // Builds a request to `https://example.com{path}` carrying the three `Sec-Fetch-*`
    // headers. The method defaults to GET and the path to `/`.
    macro_rules! request {
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, "/", site => $site, mode => $mode, dest => $dest)
        };

        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };

        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            ::http::Request::builder()
                .method($method)
                .uri(format!("https://example.com{}", $path))
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        };
    }

    // Sends `$req` through `$layer` (the default layer when omitted) wrapped around a
    // mock inner service that answers with `http::Response::new(())`. The resulting
    // response is handed to `$assert_resp`.
    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };

        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (service, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);

            tokio::spawn(async move {
                // A denied request never reaches the mock, so the channel just closes.
                // An assertion here could only panic inside this detached task, which
                // never fails the test; the status checks below are what matter.
                if let Some((_, send)) = handler.next_request().await {
                    send.send_response(http::Response::new(()));
                }
            });

            let response = service.into_inner().oneshot($req).await.unwrap();

            ($assert_resp)(response);
        };
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        let request = http::Request::new(());

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });
        let request = http::Request::new(());

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_cross_origin_requests() {
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_with_safe_methods_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let request =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_navigation_requests_resulting_from_embedding() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "iframe");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_ignores_explicitly_authorized_requests() {
        let layer = SecFetchLayer::default().allowing(["/allowed"]);
        let request = request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let layer = SecFetchLayer::default().no_enforce();
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_denied_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }

    #[tokio::test]
    async fn it_reports_denied_requests_even_if_enforcement_is_turned_off() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default()
            .no_enforce()
            .with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called for a policy violation in report-only mode"
        );
    }
}