// tower_sec_fetch/lib.rs

//! # Cookieless CSRF protection library
//!
//! This crate provides a [Tower] middleware that implements [Cross-Site-Request-Forgery] protection by validating the [Fetch Metadata] headers of the incoming HTTP request. It does not require cookies, signing keys, or tokens.
//!
//! If you're looking for a classic CSRF cookie implementation, try [tower-surf] instead.
//!
//! ## Overview
//!
//! For a more in-depth explanation of the problem CSRF protection is trying to solve, and why using signed cookies is not always the best solution, refer to [this excellent writeup](https://github.com/golang/go/issues/73626) by [Filippo Valsorda](https://filippo.io).
//!
//! In short, this crate lets you protect web resources from cross-site inclusion and abuse by validating the [Fetch Metadata] headers and ensuring that only "safe" requests are allowed. In this context, "safe" means:
//!
//! - the request comes from the same origin (the site's exact scheme, host, and port) or the same site (any subdomain of the current domain), or is user-initiated (e.g. clicking on a bookmark or directly entering the website's address), OR...
//! - the request is a simple GET request coming from a navigation event (e.g. clicking on a link on another website), as long as it's not being embedded in elements like `<object>` or `<iframe>`.
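//!
//! As a rough illustration, the two rules above can be sketched as a standalone function. This is a simplified, hypothetical model and not the crate's actual policy code: `FetchMetadata` and `is_safe` are made-up names, and treating only `object` and `iframe` as embedding destinations is an assumption.
//!
//! ```rust
//! /// Hypothetical, simplified view of the Fetch Metadata headers of a request.
//! struct FetchMetadata<'a> {
//!     site: Option<&'a str>, // Sec-Fetch-Site
//!     mode: Option<&'a str>, // Sec-Fetch-Mode
//!     dest: Option<&'a str>, // Sec-Fetch-Dest
//! }
//!
//! /// Returns `true` if the request counts as "safe" under the rules above.
//! fn is_safe(method: &str, meta: &FetchMetadata<'_>) -> bool {
//!     // Missing metadata: accepted by default.
//!     let Some(site) = meta.site else {
//!         return true;
//!     };
//!
//!     // Same-origin, same-site, and user-initiated ("none") requests are safe.
//!     if matches!(site, "same-origin" | "same-site" | "none") {
//!         return true;
//!     }
//!
//!     // Otherwise, only plain GET navigations that are not embedded are safe.
//!     // Assumption: `object` and `iframe` stand in for the embedding destinations.
//!     let embedded = matches!(meta.dest, Some("object" | "iframe"));
//!     method == "GET" && meta.mode == Some("navigate") && !embedded
//! }
//! ```
//!
//! Under this sketch, a cross-site `GET` navigation to a `document` is accepted, while a cross-site `GET` in `cors` mode, or any cross-site `POST`, is rejected.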
//!
//! <div class="warning">
//!
//! If the request does not include the Fetch Metadata headers, as is the case for requests coming from a non-browser user agent or from a browser released before [2023](https://caniuse.com/mdn-http_headers_sec-fetch-site), the request will be accepted.
//!
//! You can change this behaviour by setting the [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flag on the evaluation policy, but it might make your website inaccessible to some users. Note that this is not a good protection against non-browser clients, as they can set the necessary headers anyway.
//!
//! </div>
//!
//! ## Usage
//!
//! Add the library to your `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! tower-sec-fetch = "*"
//! ```
//!
//! Here's how to use it with [Axum], but it works with any Tower-based server.
//!
//! ```
//! # use axum::routing::get;
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! # fn main() {
//! let routes = axum::Router::new()
//!     .route("/hello", get(async || "hello"))
//!     .layer(SecFetchLayer::default());
//! #
//! # let routes: axum::Router = routes;
//! # }
//! ```
//!
//! Specific paths can be explicitly allowed.
//!
//! ```
//! # use axum::routing::get;
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! # fn main() {
//! let routes = axum::Router::new()
//!     .route("/hello", get(async || "hello"))
//!     .route("/unprotected", get(async || "unprotected"))
//!     .layer(SecFetchLayer::default().allowing(["/unprotected"]));
//! #
//! # let routes: axum::Router = routes;
//! # }
//! ```
//!
//! You can override the default authorization logic with a custom [SecFetchAuthorizer].
//!
//! ```
//! use tower_sec_fetch::{AuthorizationDecision, SecFetchAuthorizer, SecFetchLayer};
//!
//! struct MyAuthorizer;
//!
//! impl SecFetchAuthorizer for MyAuthorizer {
//!     fn authorize<B>(&self, request: &http::Request<B>) -> AuthorizationDecision {
//!         // allow all requests addressed to a specific host
//!         // (`uri().host()` is the host the request targets, not where it came from)
//!         if request.uri().host() == Some("my-domain.com") {
//!             return AuthorizationDecision::Allowed;
//!         }
//!
//!         // otherwise, continue with the regular evaluation policy
//!         AuthorizationDecision::Continue
//!     }
//! }
//!
//! SecFetchLayer::default().with_authorizer(MyAuthorizer);
//! ```
//!
//! You can provide a [SecFetchReporter] implementation to be notified when a request is blocked. This is useful for analytics and monitoring. Combined with the [no_enforce](SecFetchLayer::no_enforce) flag, it also lets you introduce this middleware incrementally into an existing system where there is a risk of blocking legitimate requests by accident: denied requests are reported but still allowed through.
//!
//! ```
//! use tower_sec_fetch::{SecFetchLayer, SecFetchReporter};
//!
//! struct LogReporter;
//!
//! impl SecFetchReporter for LogReporter {
//!     fn on_request_denied<B>(&self, request: &http::Request<B>) {
//!         let uri = request.uri();
//!         let method = request.method();
//!         let headers = request.headers();
//!
//!         eprintln!("request was denied: {method} {uri} {headers:?}");
//!     }
//! }
//!
//! SecFetchLayer::default().no_enforce().with_reporter(LogReporter);
//! ```
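//!
//! The order in which the authorizer, the evaluation policy, the reporter, and the enforcement flag interact can be sketched with a few stand-in types. `Decision`, `Outcome`, and `evaluate` are illustrative names and not part of this crate's API. The sketch only mirrors the control flow of the middleware's `call` method.
//!
//! ```rust
//! /// Stand-in for the decision returned by a custom authorizer.
//! #[derive(Debug, Clone, Copy, PartialEq, Eq)]
//! enum Decision {
//!     Allowed,
//!     Denied,
//!     Continue,
//! }
//!
//! /// What the middleware ends up doing with the request.
//! #[derive(Debug, Clone, Copy, PartialEq, Eq)]
//! enum Outcome {
//!     Forwarded,
//!     Forbidden,
//! }
//!
//! /// Returns the outcome and whether the reporter would have been notified.
//! fn evaluate(decision: Decision, policy_allows: bool, enforce: bool) -> (Outcome, bool) {
//!     // The authorizer has the first say and short-circuits without reporting.
//!     match decision {
//!         Decision::Allowed => return (Outcome::Forwarded, false),
//!         Decision::Denied => return (Outcome::Forbidden, false),
//!         Decision::Continue => {}
//!     }
//!
//!     if policy_allows {
//!         return (Outcome::Forwarded, false);
//!     }
//!
//!     // The policy denied the request: the reporter is notified either way,
//!     // and the request is only rejected when enforcement is on.
//!     let outcome = if enforce {
//!         Outcome::Forbidden
//!     } else {
//!         Outcome::Forwarded
//!     };
//!     (outcome, true)
//! }
//! ```
//!
//! In this model, `evaluate(Decision::Continue, false, false)` yields `(Outcome::Forwarded, true)`: the denial is reported, but the request still reaches the inner service.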
//!
//! By default, [safe methods](https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP) are not allowed for cross-origin requests (other than the navigation requests described in the overview). You can allow them by setting the [allow_safe_methods](PolicyBuilder::allow_safe_methods) flag on the evaluation policy.
//!
//! ```
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! SecFetchLayer::new(|policy| {
//!     policy.allow_safe_methods();
//! });
//! ```
//!
//! If the Fetch Metadata headers are missing, the request is allowed by default. To reject such requests instead, set the [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flag on the evaluation policy.
//!
//! ```
//! # use tower_sec_fetch::SecFetchLayer;
//! #
//! SecFetchLayer::new(|policy| {
//!     policy.reject_missing_metadata();
//! });
//! ```
//!
//! [Tower]: https://docs.rs/tower
//! [Cross-Site-Request-Forgery]: https://developer.mozilla.org/en-US/docs/Web/Security/Attacks/CSRF
//! [Fetch Metadata]: https://developer.mozilla.org/en-US/docs/Glossary/Fetch_metadata_request_header
//! [tower-surf]: https://docs.rs/tower-surf
//! [Axum]: https://docs.rs/axum

use std::sync::Arc;

use futures::future::{self, Either, Ready};
use http::StatusCode;
use policy::Policy;
use tower::{Layer, Service};

pub use authorizer::*;
pub use policy::PolicyBuilder;
pub use reporter::*;

mod authorizer;
pub mod header;
mod policy;
mod reporter;

/// Layer that applies [SecFetch], which validates requests against CSRF attacks
pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
    enforce: bool,
    policy: Policy,
    authorizer: Arc<A>,
    reporter: Arc<R>,
}

impl<A, R> Clone for SecFetchLayer<A, R> {
    fn clone(&self) -> Self {
        Self {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer.clone(),
            reporter: self.reporter.clone(),
        }
    }
}

impl Default for SecFetchLayer {
    fn default() -> Self {
        Self {
            enforce: true,
            policy: Policy::default(),
            authorizer: Arc::new(NoopAuthorizer),
            reporter: Arc::new(NoopReporter),
        }
    }
}

impl SecFetchLayer {
    /// Creates a layer whose evaluation policy is configured by the given closure.
    pub fn new<F>(make_policy: F) -> Self
    where
        F: FnOnce(&mut PolicyBuilder),
    {
        let mut builder = PolicyBuilder::new();
        make_policy(&mut builder);
        let policy = builder.build();
        Self {
            policy,
            ..Default::default()
        }
    }
}

impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
    /// Explicitly allows requests to the given paths, replacing any previously configured authorizer.
    pub fn allowing(
        self,
        paths: impl Into<Arc<[&'static str]>>,
    ) -> SecFetchLayer<PathAuthorizer, OldR> {
        self.with_authorizer(PathAuthorizer::new(paths))
    }

    /// Disables enforcement: requests denied by the policy are still reported, but are let through.
    pub fn no_enforce(mut self) -> Self {
        self.enforce = false;
        self
    }

    /// Sets the [SecFetchAuthorizer] consulted before the evaluation policy.
    pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
        SecFetchLayer {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: Arc::from(authorizer),
            reporter: self.reporter,
        }
    }

    /// Sets the [SecFetchReporter] notified when the policy denies a request.
    pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
        SecFetchLayer {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer,
            reporter: Arc::from(reporter),
        }
    }
}

impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
    type Service = SecFetch<A, R, S>;

    fn layer(&self, inner: S) -> Self::Service {
        SecFetch {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer.clone(),
            reporter: self.reporter.clone(),
            inner,
        }
    }
}

/// Middleware protecting against CSRF attacks
pub struct SecFetch<A, R, S> {
    enforce: bool,
    policy: Policy,
    authorizer: Arc<A>,
    reporter: Arc<R>,
    inner: S,
}

impl<A, R, S> Clone for SecFetch<A, R, S>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            enforce: self.enforce,
            policy: self.policy,
            authorizer: self.authorizer.clone(),
            reporter: self.reporter.clone(),
            inner: self.inner.clone(),
        }
    }
}

impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
where
    A: SecFetchAuthorizer,
    R: SecFetchReporter,
    S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
    ResB: Default,
{
    type Response = S::Response;

    type Error = S::Error;

    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;

    #[inline]
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
        // forward the request to the inner service
        let mut allow = |request| Either::Left(self.inner.call(request));
        // short-circuit with an empty 403 Forbidden response
        let deny = || {
            Either::Right(future::ready(Ok(http::Response::builder()
                .status(StatusCode::FORBIDDEN)
                .body(ResB::default())
                .expect("valid response"))))
        };

        // the custom authorizer is consulted first and can short-circuit the policy
        match self.authorizer.authorize(&request) {
            AuthorizationDecision::Allowed => return allow(request),
            AuthorizationDecision::Denied => return deny(),
            AuthorizationDecision::Continue => {}
        }

        if self.policy.allow(&request) {
            return allow(request);
        }

        self.reporter.on_request_denied(&request);

        // the policy denied the request, but enforcement is disabled:
        // the denial has been reported, so we let the request continue
        if !self.enforce {
            return allow(request);
        }

        deny()
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::{check, let_assert};
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    macro_rules! request {
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, "/", site => $site, mode => $mode, dest => $dest)
        };

        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };

        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            ::http::Request::builder()
                .method($method)
                .uri(format!("https://example.com{}", $path))
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        };
    }

    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };

        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (service, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);

            tokio::spawn(async move {
                let_assert!(Some((_, send)) = handler.next_request().await);
                send.send_response(http::Response::new(()));
            });

            let response = service.into_inner().oneshot($req).await.unwrap();

            ($assert_resp)(response);
        };
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        let request = http::Request::new(());

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });
        let request = http::Request::new(());

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let request = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_disallows_cross_origin_requests() {
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status() == StatusCode::FORBIDDEN);
        });
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_safe_methods_if_configured() {
        let layer = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let request =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let request = request!(site => "cross-site", mode => "navigate", dest => "document");

        assert_request!(request, |response: http::Response<()>| {
            check!(response.status().is_success());
        });
    }

    #[tokio::test]
    async fn it_ignores_explicitly_authorized_requests() {
        let layer = SecFetchLayer::default().allowing(["/allowed"]);
        let request = request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let layer = SecFetchLayer::default().no_enforce();
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status().is_success());
            },
            layer
        );
    }

    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_denied_requests() {
        let reporter = Arc::new(TestReporter::default());
        let layer = SecFetchLayer::default().with_reporter(reporter.clone());
        let request = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            request,
            |response: http::Response<()>| {
                check!(response.status() == StatusCode::FORBIDDEN);
            },
            layer
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }
}