//! # Cookieless CSRF protection library
//!
//! This crate provides a [Tower] middleware that implements [Cross-Site Request Forgery][Cross-Site-Request-Forgery] protection by validating the [Fetch Metadata] headers of incoming HTTP requests. It requires no cookies, signing keys, or tokens.
//!
//! If you're looking for a classic CSRF cookie implementation, try [tower-surf] instead.
//!
//! ## Overview
//!
//! For a more in-depth explanation of the problem CSRF protection is trying to solve, and why using signed cookies is not always the best solution, refer to [this excellent writeup](https://github.com/golang/go/issues/73626) by [Filippo Valsorda](https://filippo.io).
//!
//! In short, this crate protects web resources from cross-site inclusion and abuse by validating the [Fetch Metadata] headers and rejecting any request that is not considered "safe". In this context, a request is "safe" if:
//!
//! - it comes from the same origin (the site's exact scheme, host, and port), from the same site (same scheme and registrable domain, e.g. `app.example.com` and `api.example.com`), or is user-initiated (e.g. clicking on a bookmark, or typing the website's address directly), OR...
//! - it is a simple GET request coming from a navigation event (e.g. clicking on a link on another website), as long as it's not being embedded in elements like `<object>` or `<iframe>`.
//!
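//! The rules above boil down to a small decision table. The function below is a hypothetical, simplified model written only for illustration. It is not this crate's internal policy, and it ignores the [allow_safe_methods](PolicyBuilder::allow_safe_methods) and [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flags. It also assumes that only navigations with a `document` destination count as non-embedded.
//!
//! ```
//! // Hypothetical sketch of the default rules, not the crate's actual implementation.
//! fn is_safe(method: &str, site: Option<&str>, mode: &str, dest: &str) -> bool {
//!     match site {
//!         // no Fetch Metadata at all: accepted by default
//!         None => true,
//!         // same origin, same site, or user-initiated
//!         Some("same-origin" | "same-site" | "none") => true,
//!         // cross-site: only a GET navigation that is not embedded
//!         // (assumption: embedding shows up as a non-`document` destination)
//!         Some(_) => method == "GET" && mode == "navigate" && dest == "document",
//!     }
//! }
//!
//! assert!(is_safe("POST", Some("same-origin"), "cors", "empty"));
//! assert!(is_safe("GET", Some("none"), "navigate", "document"));
//! assert!(is_safe("GET", Some("cross-site"), "navigate", "document"));
//! assert!(!is_safe("GET", Some("cross-site"), "navigate", "iframe"));
//! assert!(!is_safe("POST", Some("cross-site"), "cors", "empty"));
//! assert!(is_safe("POST", None, "", ""));
//! ```
//!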
//! <div class="warning">
//!
//! Requests that do not include the Fetch Metadata headers are accepted. This is the case for requests sent by non-browser user agents, or by older browsers: Fetch Metadata has only been supported by all major browsers since [2023](https://caniuse.com/mdn-http_headers_sec-fetch-site).
//!
//! You can change this behaviour by setting the [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flag on the evaluation policy, but doing so might make your website inaccessible to some users. Note that this is not a reliable protection against non-browser clients either, as they can set the expected headers themselves.
//!
//! </div>
//!
//! ## Usage
//!
//! Add the library to your Cargo.toml:
//!
//! ```toml
//! [dependencies]
//! tower-sec-fetch = "*"
//! ```
//!
//! Here's how to use it with [Axum], although it works with any Tower-based server.
//!
//! ```
//! # use axum::routing::get;
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! # fn main() {
//! let routes = axum::Router::new()
//!     .route("/hello", get(async || "hello"))
//!     .layer(SecFetchLayer::default());
//! #
//! # let routes: axum::Router = routes;
//! # }
//! ```
//!
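//! Outside of Axum, the layer can wrap any [`tower::Service`] that accepts an `http::Request` and returns an `http::Response` with a `Default` body. The sketch below assumes `tokio` (with the `macros` and `rt` features) and `tower` (with the `util` feature) are available. The `request` helper is hypothetical and only builds a request carrying Fetch Metadata headers.
//!
//! ```
//! use std::convert::Infallible;
//!
//! use http::{Request, Response, StatusCode};
//! use tower::{service_fn, Layer, ServiceExt};
//! use tower_sec_fetch::SecFetchLayer;
//!
//! // hypothetical helper: a request carrying the three Fetch Metadata headers
//! fn request(site: &str, mode: &str, dest: &str) -> Request<()> {
//!     Request::builder()
//!         .uri("https://example.com/")
//!         .header("sec-fetch-site", site)
//!         .header("sec-fetch-mode", mode)
//!         .header("sec-fetch-dest", dest)
//!         .body(())
//!         .unwrap()
//! }
//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! let inner = service_fn(|_req: Request<()>| async { Ok::<_, Infallible>(Response::new(())) });
//! let service = SecFetchLayer::default().layer(inner);
//!
//! // a cross-site `fetch()` call is answered with 403 without reaching `inner`
//! let denied = service.clone().oneshot(request("cross-site", "cors", "empty")).await.unwrap();
//! assert_eq!(denied.status(), StatusCode::FORBIDDEN);
//!
//! // a same-origin request is forwarded to `inner`
//! let allowed = service.oneshot(request("same-origin", "cors", "empty")).await.unwrap();
//! assert_eq!(allowed.status(), StatusCode::OK);
//! # }
//! ```
//!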
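//! Specific paths can be explicitly exempted from the check with [allowing](SecFetchLayer::allowing).
//!
//! ```
//! # use axum::routing::get;
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! # fn main() {
//! let routes = axum::Router::new()
//!     .route("/hello", get(async || "hello"))
//!     .route("/unprotected", get(async || "unprotected"))
//!     .layer(SecFetchLayer::default().allowing(["/unprotected"]));
//! #
//! # let routes: axum::Router = routes;
//! # }
//! ```
//!
//! Since `allowing` works by installing a path-based [SecFetchAuthorizer], it replaces any authorizer previously set with [with_authorizer](SecFetchLayer::with_authorizer), and a later call to `with_authorizer` replaces it in turn.
//!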
//! You can override the default authorization logic with a custom [SecFetchAuthorizer]. The authorizer runs before the evaluation policy: returning [AuthorizationDecision::Allowed] or [AuthorizationDecision::Denied] settles the request immediately, while [AuthorizationDecision::Continue] hands it over to the policy.
//!
//! ```
//! use tower_sec_fetch::{AuthorizationDecision, SecFetchAuthorizer, SecFetchLayer};
//!
//! struct MyAuthorizer;
//!
//! impl SecFetchAuthorizer for MyAuthorizer {
//!     fn authorize<B>(&self, request: &http::Request<B>) -> AuthorizationDecision {
//!         // allow all requests whose `Origin` header matches a trusted origin
//!         let origin = request.headers().get(http::header::ORIGIN);
//!         if origin.is_some_and(|origin| origin == "https://my-domain.com") {
//!             return AuthorizationDecision::Allowed;
//!         }
//!
//!         // otherwise, continue with the regular evaluation policy
//!         AuthorizationDecision::Continue
//!     }
//! }
//!
//! SecFetchLayer::default().with_authorizer(MyAuthorizer);
//! ```
//!
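//! An authorizer can also reject requests outright. In the hypothetical example below, `StrictAdmin` requires Fetch Metadata on paths starting with `/admin`, while the rest of the site keeps accepting requests without it. Every other request is left to the evaluation policy.
//!
//! ```
//! use tower_sec_fetch::{AuthorizationDecision, SecFetchAuthorizer, SecFetchLayer};
//!
//! // hypothetical authorizer: admin pages must carry Fetch Metadata
//! struct StrictAdmin;
//!
//! impl SecFetchAuthorizer for StrictAdmin {
//!     fn authorize<B>(&self, request: &http::Request<B>) -> AuthorizationDecision {
//!         let is_admin = request.uri().path().starts_with("/admin");
//!         let has_metadata = request.headers().contains_key("sec-fetch-site");
//!
//!         if is_admin && !has_metadata {
//!             // rejected immediately, the evaluation policy is not consulted
//!             return AuthorizationDecision::Denied;
//!         }
//!
//!         // everything else goes through the regular evaluation policy
//!         AuthorizationDecision::Continue
//!     }
//! }
//!
//! let request = http::Request::get("/admin/users").body(()).unwrap();
//! assert!(matches!(StrictAdmin.authorize(&request), AuthorizationDecision::Denied));
//!
//! SecFetchLayer::default().with_authorizer(StrictAdmin);
//! ```
//!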
//! You can provide a [SecFetchReporter] implementation to be notified whenever the evaluation policy rejects a request. This is useful for analytics and monitoring. Combined with the [no_enforce](SecFetchLayer::no_enforce) flag, it also lets you introduce this middleware incrementally into an existing system. Rejected requests are reported but still served, so you can spot legitimate traffic that would be blocked before turning enforcement on.
//!
//! ```
//! use tower_sec_fetch::{SecFetchLayer, SecFetchReporter};
//!
//! struct LogReporter;
//!
//! impl SecFetchReporter for LogReporter {
//!     fn on_request_denied<B>(&self, request: &http::Request<B>) {
//!         let uri = request.uri();
//!         let method = request.method();
//!         let headers = request.headers();
//!
//!         eprintln!("request was denied: {method} {uri} {headers:?}");
//!     }
//! }
//!
//! SecFetchLayer::default().no_enforce().with_reporter(LogReporter);
//! ```
//!
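//! The following sketch shows such a report-only rollout. It assumes `tokio` (with the `macros` and `rt` features) and `tower` (with the `util` feature) are available. The hypothetical `DenialCounter` reporter counts policy violations, e.g. to feed a metric, while [no_enforce](SecFetchLayer::no_enforce) keeps serving the offending requests. The reporter is shared through an `Arc` so the count can still be read after the layer takes ownership of its copy.
//!
//! ```
//! use std::convert::Infallible;
//! use std::sync::atomic::{AtomicUsize, Ordering};
//! use std::sync::Arc;
//!
//! use http::{Request, Response, StatusCode};
//! use tower::{service_fn, Layer, ServiceExt};
//! use tower_sec_fetch::{SecFetchLayer, SecFetchReporter};
//!
//! // hypothetical reporter that only counts rejected requests
//! #[derive(Default)]
//! struct DenialCounter(AtomicUsize);
//!
//! impl SecFetchReporter for DenialCounter {
//!     fn on_request_denied<B>(&self, _request: &Request<B>) {
//!         self.0.fetch_add(1, Ordering::Relaxed);
//!     }
//! }
//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! let counter = Arc::new(DenialCounter::default());
//! let layer = SecFetchLayer::default().no_enforce().with_reporter(counter.clone());
//! let inner = service_fn(|_req: Request<()>| async { Ok::<_, Infallible>(Response::new(())) });
//!
//! // a cross-site `fetch()` call violates the default policy...
//! let request = Request::builder()
//!     .uri("https://example.com/")
//!     .header("sec-fetch-site", "cross-site")
//!     .header("sec-fetch-mode", "cors")
//!     .header("sec-fetch-dest", "empty")
//!     .body(())
//!     .unwrap();
//! let response = layer.layer(inner).oneshot(request).await.unwrap();
//!
//! // ...so it is reported, but still served because enforcement is off
//! assert_eq!(counter.0.load(Ordering::Relaxed), 1);
//! assert_eq!(response.status(), StatusCode::OK);
//! # }
//! ```
//!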
//! By default, cross-site requests are rejected even when they use a [safe method](https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP) such as GET or HEAD, unless they are top-level navigations. Setting the [allow_safe_methods](PolicyBuilder::allow_safe_methods) flag on the evaluation policy lets them through. Only enable it if the handlers reached through safe methods do not change server state.
//!
//! ```
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! SecFetchLayer::new(|policy| {
//!     policy.allow_safe_methods();
//! });
//! ```
//!
//! By default, requests without Fetch Metadata headers are allowed. Setting the [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flag on the evaluation policy rejects them instead.
//!
//! ```
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! SecFetchLayer::new(|policy| {
//!     policy.reject_missing_metadata();
//! });
//! ```
//!
//! [Tower]: https://docs.rs/tower
//! [Cross-Site-Request-Forgery]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
//! [Fetch Metadata]: https://developer.mozilla.org/en-US/docs/Glossary/Fetch_metadata_request_header
//! [tower-surf]: https://docs.rs/tower-surf
//! [Axum]: https://docs.rs/axum

use std::sync::Arc;

use futures::future::{self, Either, Ready};
use http::StatusCode;
use policy::Policy;
use tower::{Layer, Service};

pub use authorizer::*;
pub use policy::PolicyBuilder;
pub use reporter::*;

mod authorizer;
pub mod header;
mod policy;
mod reporter;

/// Layer that applies the [SecFetch] middleware, which validates requests against CSRF attacks.
pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
    enforce: bool,
    policy: Policy,
    authorizer: Arc<A>,
    reporter: Arc<R>,
}

impl<A, R> Clone for SecFetchLayer<A, R> {
    fn clone(&self) -> Self {
        Self {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer.clone(),
            reporter: self.reporter.clone(),
        }
    }
}

impl Default for SecFetchLayer {
    fn default() -> Self {
        Self {
            enforce: true,
            policy: Policy::default(),
            authorizer: Arc::new(NoopAuthorizer),
            reporter: Arc::new(NoopReporter),
        }
    }
}

impl SecFetchLayer {
    /// Creates a layer whose evaluation policy is configured by `make_policy`.
    ///
    /// Enforcement, the authorizer, and the reporter keep their default values.
    pub fn new<F>(make_policy: F) -> Self
    where
        F: FnOnce(&mut PolicyBuilder),
    {
        let mut builder = PolicyBuilder::new();
        make_policy(&mut builder);
        let policy = builder.build();
        Self {
            policy,
            ..Default::default()
        }
    }
}

impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
    /// Exempts the given paths from the check by installing a [PathAuthorizer].
    ///
    /// This replaces any authorizer previously set with [with_authorizer](Self::with_authorizer).
    pub fn allowing(
        self,
        paths: impl Into<Arc<[&'static str]>>,
    ) -> SecFetchLayer<PathAuthorizer, OldR> {
        self.with_authorizer(PathAuthorizer::new(paths))
    }

    /// Turns on report-only mode: requests rejected by the evaluation policy are
    /// passed to the reporter and then forwarded to the inner service anyway.
    ///
    /// Requests denied by the authorizer are still rejected.
    pub fn no_enforce(mut self) -> Self {
        self.enforce = false;
        self
    }

    /// Replaces the authorizer, which is consulted before the evaluation policy.
    pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
        SecFetchLayer {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: Arc::from(authorizer),
            reporter: self.reporter,
        }
    }

    /// Replaces the reporter, which is notified whenever the evaluation policy rejects a request.
    pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
        SecFetchLayer {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer,
            reporter: Arc::from(reporter),
        }
    }
}

impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
    type Service = SecFetch<A, R, S>;

    fn layer(&self, inner: S) -> Self::Service {
        SecFetch {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer.clone(),
            reporter: self.reporter.clone(),
            inner,
        }
    }
}

/// Middleware that protects the inner service against CSRF attacks, created by [SecFetchLayer].
pub struct SecFetch<A, R, S> {
    enforce: bool,
    policy: Policy,
    authorizer: Arc<A>,
    reporter: Arc<R>,
    inner: S,
}

impl<A, R, S> Clone for SecFetch<A, R, S>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer.clone(),
            reporter: self.reporter.clone(),
            inner: self.inner.clone(),
        }
    }
}

impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
where
    A: SecFetchAuthorizer,
    R: SecFetchReporter,
    S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
    ResB: Default,
{
    type Response = S::Response;

    type Error = S::Error;

    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;

    #[inline]
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
        #[cfg(feature = "tracing")]
        tracing::debug!(
            method = %request.method(),
            path = request.uri().path(),
            "processing request",
        );

        let mut allow = |request: http::Request<ReqB>| {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                method = %request.method(),
                path = request.uri().path(),
                "request allowed",
            );

            Either::Left(self.inner.call(request))
        };

        let deny = || {
            #[cfg(feature = "tracing")]
            tracing::debug!(
                method = %request.method(),
                path = request.uri().path(),
                "request denied",
            );

            Either::Right(future::ready(Ok(http::Response::builder()
                .status(StatusCode::FORBIDDEN)
                .body(ResB::default())
                .expect("valid response"))))
        };

        match self.authorizer.authorize(&request) {
            AuthorizationDecision::Allowed => return allow(request),
            AuthorizationDecision::Denied => return deny(),
            AuthorizationDecision::Continue => {}
        }

        if self.policy.allow(&request) {
            return allow(request);
        }

        self.reporter.on_request_denied(&request);

        // the request was denied, but we are not enforcing it
        // we report the failure and let the request continue
        if !self.enforce {
            return allow(request);
        }

        deny()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::{check, let_assert};
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    macro_rules! request {
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, "/", site => $site, mode => $mode, dest => $dest)
        };

        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };

        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            ::http::Request::builder()
                .method($method)
                .uri(format!("https://example.com{}", $path))
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        };
    }

    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };

        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (service, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);

            tokio::spawn(async move {
                let_assert!(Some((_, send)) = handler.next_request().await);
                send.send_response(http::Response::new(()));
            });

            let response = service.into_inner().oneshot($req).await.unwrap();

            ($assert_resp)(response);
        };
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        let request = http::Request::new(());

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });
        let request = http::Request::new(());

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_cross_origin_requests() {
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_safe_methods_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let request =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_navigation_requests_resulting_from_embedding() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "iframe");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_ignores_explicitly_authorized_requests() {
        let layer = SecFetchLayer::default().allowing(["/allowed"]);
        let request = request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let layer = SecFetchLayer::default().no_enforce();
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_denied_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }
}