1use std::sync::Arc;
132
133use futures::future::{self, Either, Ready};
134use http::StatusCode;
135use policy::Policy;
136use tower::{Layer, Service};
137
138pub use authorizer::*;
139pub use policy::PolicyBuilder;
140pub use reporter::*;
141
142mod authorizer;
143pub mod header;
144mod policy;
145mod reporter;
146
147pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
149 enforce: bool,
150 policy: Policy,
151 authorizer: Arc<A>,
152 reporter: Arc<R>,
153}
154
155impl<A, R> Clone for SecFetchLayer<A, R> {
156 fn clone(&self) -> Self {
157 Self {
158 enforce: self.enforce,
159 policy: self.policy,
160 authorizer: self.authorizer.clone(),
161 reporter: self.reporter.clone(),
162 }
163 }
164}
165
166impl Default for SecFetchLayer {
167 fn default() -> Self {
168 Self {
169 enforce: true,
170 policy: Policy::default(),
171 authorizer: Arc::new(NoopAuthorizer),
172 reporter: Arc::new(NoopReporter),
173 }
174 }
175}
176
177impl SecFetchLayer {
178 pub fn new<F>(make_policy: F) -> Self
179 where
180 F: FnOnce(&mut PolicyBuilder),
181 {
182 let mut builder = PolicyBuilder::new();
183 make_policy(&mut builder);
184 let policy = builder.build();
185 Self {
186 policy,
187 ..Default::default()
188 }
189 }
190}
191
192impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
193 pub fn allowing(
194 self,
195 paths: impl Into<Arc<[&'static str]>>,
196 ) -> SecFetchLayer<PathAuthorizer, OldR> {
197 self.with_authorizer(PathAuthorizer::new(paths))
198 }
199
200 pub fn no_enforce(mut self) -> Self {
201 self.enforce = false;
202 self
203 }
204
205 pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
206 SecFetchLayer {
207 enforce: self.enforce,
208 policy: self.policy,
209 authorizer: Arc::from(authorizer),
210 reporter: self.reporter,
211 }
212 }
213
214 pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
215 SecFetchLayer {
216 enforce: self.enforce,
217 policy: self.policy,
218 authorizer: self.authorizer,
219 reporter: Arc::from(reporter),
220 }
221 }
222}
223
224impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
225 type Service = SecFetch<A, R, S>;
226
227 fn layer(&self, inner: S) -> Self::Service {
228 SecFetch {
229 enforce: self.enforce,
230 policy: self.policy,
231 authorizer: self.authorizer.clone(),
232 reporter: self.reporter.clone(),
233 inner,
234 }
235 }
236}
237
238pub struct SecFetch<A, R, S> {
240 enforce: bool,
241 policy: Policy,
242 authorizer: Arc<A>,
243 reporter: Arc<R>,
244 inner: S,
245}
246
247impl<A, R, S> Clone for SecFetch<A, R, S>
248where
249 S: Clone,
250{
251 fn clone(&self) -> Self {
252 Self {
253 enforce: self.enforce,
254 policy: self.policy,
255 authorizer: self.authorizer.clone(),
256 reporter: self.reporter.clone(),
257 inner: self.inner.clone(),
258 }
259 }
260}
261
262impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
263where
264 A: SecFetchAuthorizer,
265 R: SecFetchReporter,
266 S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
267 ResB: Default,
268{
269 type Response = S::Response;
270
271 type Error = S::Error;
272
273 type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;
274
275 #[inline]
276 fn poll_ready(
277 &mut self,
278 cx: &mut std::task::Context<'_>,
279 ) -> std::task::Poll<Result<(), Self::Error>> {
280 self.inner.poll_ready(cx)
281 }
282
283 fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
284 let mut allow = |request| Either::Left(self.inner.call(request));
285 let deny = || {
286 Either::Right(future::ready(Ok(http::Response::builder()
287 .status(StatusCode::FORBIDDEN)
288 .body(ResB::default())
289 .expect("valid response"))))
290 };
291
292 match self.authorizer.authorize(&request) {
293 AuthorizationDecision::Allowed => return allow(request),
294 AuthorizationDecision::Denied => return deny(),
295 AuthorizationDecision::Continue => {}
296 }
297
298 if self.policy.allow(&request) {
299 return allow(request);
300 }
301
302 self.reporter.on_request_denied(&request);
303
304 if !self.enforce {
307 return allow(request);
308 }
309
310 deny()
311 }
312}
313
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::check;
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    // Builds a `()`-bodied request to `https://example.com{path}` carrying the
    // three `Sec-Fetch-*` headers. The method defaults to `GET` and the path
    // to `/`.
    macro_rules! request {
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, "/", site => $site, mode => $mode, dest => $dest)
        };

        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };

        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            ::http::Request::builder()
                .method($method)
                .uri(format!("https://example.com{}", $path))
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        };
    }

    // Sends `$req` through `$layer` (default: `SecFetchLayer::default()`) in
    // front of a mock inner service that answers `200 OK`, then runs
    // `$assert_resp` on the resulting response.
    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };

        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (service, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);

            tokio::spawn(async move {
                // Rejected requests never reach the mock, so the handle may
                // observe the service being dropped without any request; that
                // is expected and must not panic.
                if let Some((_, send)) = handler.next_request().await {
                    send.send_response(http::Response::new(()));
                }
            });

            let response = service.into_inner().oneshot($req).await.unwrap();

            ($assert_resp)(response);
        };
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        let request = http::Request::new(());

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });
        let request = http::Request::new(());

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_disallows_cross_origin_requests() {
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_safe_methods_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let request =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_ignores_explicitly_authorized_requests() {
        let layer = SecFetchLayer::default().allowing(["/allowed"]);
        let request = request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let layer = SecFetchLayer::default().no_enforce();
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    // Records whether `on_request_denied` was ever invoked.
    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_denied_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }

    #[tokio::test]
    async fn it_reports_denied_requests_even_if_enforcement_is_turned_off() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default()
            .no_enforce()
            .with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called although the policy rejected the request in report-only mode"
        );
    }

    #[tokio::test]
    async fn it_does_not_report_allowed_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            !called,
            "reporter was called although the request was allowed"
        );
    }
}