tower_sec_fetch/
lib.rs

1//! # Cookieless CSRF protection library
2//!
3//! This crate provides a [Tower] middleware that implements CSRF protection by validating the [Fetch Metadata] headers of the incoming HTTP request. It does not require cookies, or signing keys, or tokens.
4//!
5//! If you're looking for a classic CSRF cookie implementation, try [tower-surf] instead.
6//!
7//! ## Overview
8//!
9//! For a more in-depth explanation of the problem CSRF is trying to solve, and why using signed cookies is not the best solution, refer to [this excellent writeup](https://github.com/golang/go/issues/73626) by [Filippo Valsorda](https://filippo.io).
10//!
11//! In short, this crate lets you protect web resources from cross-site inclusion and abuse by validating the [Fetch Metadata] headers and ensuring that only "safe" cross-site requests are allowed. In this context, "safe" means:
12//!
13//! - the request comes from the same origin (the site's exact scheme, host, and port), same site (any subdomain of the current domain), or is user-initiated (e.g. clicking on a bookmark, directly entering the website's address...)
14//! - the request is a regular GET request coming from a navigation event (e.g. clicking on a link on another website), as long as it's not being embedded in elements like `<object>` or `<iframe>`.
15//!
16//! <div class="warning">
17//!
18//! If the request does not include the Fetch Metadata, such as a request coming from a non-browser user-agent, or a browser released before [2023](https://caniuse.com/mdn-http_headers_sec-fetch-site), the request will be accepted.
19//!
20//! You can change this behaviour by setting the [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flag on the evaluation policy, but it might make your website not accessible to some users. Note that this is not a good protection against non-browser clients, as they can set the necessary headers anyway.
21//!
22//! </div>
23//!
24//! ## Usage
25//!
26//! Add the library to your Cargo.toml
27//!
28//! ```toml
29//! [dependencies]
30//! tower-sec-fetch = "*"
31//! ```
32//!
33//! Here's how to use it with [Axum], but it works with any tower-based server.
34//!
35//! ```
36//! # use axum::routing::get;
37//! # use tower_sec_fetch::SecFetchLayer;
38//! #
39//! # fn main() {
40//! let routes = axum::Router::new()
41//!     .route("/hello", get(async || "hello"))
42//!     .layer(SecFetchLayer::default());
43//! #
44//! # let routes: axum::Router = routes;
45//! # }
46//! ```
47//!
48//! Specific paths can be explicitly allowed.
49//!
50//! ```
51//! # use axum::routing::get;
52//! # use tower_sec_fetch::SecFetchLayer;
53//! #
54//! # fn main() {
55//! let routes = axum::Router::new()
56//!     .route("/hello", get(async || "hello"))
57//!     .route("/unprotected", get(async || "unprotected"))
58//!     .layer(SecFetchLayer::default().allowing(["/unprotected"]));
59//! #
60//! # let routes: axum::Router = routes;
61//! # }
62//! ```
63//!
64//! You can override the default authorization logic with a custom [SecFetchAuthorizer].
65//!
66//! ```
67//! use tower_sec_fetch::{AuthorizationDecision, SecFetchAuthorizer, SecFetchLayer};
68//!
69//! struct MyAuthorizer;
70//!
71//! impl SecFetchAuthorizer for MyAuthorizer {
72//!    fn authorize<B>(&self, request: &http::Request<B>) -> AuthorizationDecision {
73//!        // allow all requests whose URI targets a specific host
74//!        if request.uri().host() == Some("my-domain.com") {
75//!            return AuthorizationDecision::Allowed;
76//!        }
77//!
78//!        // otherwise, continue with the regular evaluation policy
79//!        AuthorizationDecision::Continue
80//!    }
81//! }
82//!
83//! SecFetchLayer::default().with_authorizer(MyAuthorizer);
84//! ```
85//!
86//! You can provide a [SecFetchReporter] implementation to be notified of a request being blocked. This can be useful for analytics and monitoring, but also to incrementally introduce this middleware in an existing system where there might be the risk of blocking legitimate requests by accident, when combined with the [no_enforce](SecFetchLayer::no_enforce) flag.
87//!
88//! ```
89//! use tower_sec_fetch::{SecFetchLayer, SecFetchReporter};
90//!
91//! struct LogReporter;
92//!
93//! impl SecFetchReporter for LogReporter {
94//!     fn on_request_denied<B>(&self, request: &http::Request<B>) {
95//!         let uri = request.uri();
96//!         let method = request.method();
97//!         let headers = request.headers();
98//!
99//!         eprintln!("request was denied: {method} {uri} {headers:?}");
100//!     }
101//! }
102//!
103//! SecFetchLayer::default().no_enforce().with_reporter(LogReporter);
104//! ```
105//!
106//! [Safe methods](https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP) are not allowed for cross-origin requests by default, but they can optionally be allowed by setting the [allow_safe_methods](PolicyBuilder::allow_safe_methods) flag on the evaluation policy.
107//!
108//! ```
109//! # use tower_sec_fetch::SecFetchLayer;
110//! #
111//! SecFetchLayer::new(|policy| {
112//!     policy.allow_safe_methods();
113//! });
114//! ```
115//!
116//! If the Fetch Metadata headers are missing, the request is allowed. This can be disabled by setting the [reject_missing_metadata](PolicyBuilder::reject_missing_metadata) flag on the evaluation policy.
117//!
118//! ```
119//! # use tower_sec_fetch::SecFetchLayer;
120//! #
121//! SecFetchLayer::new(|policy| {
122//!     policy.reject_missing_metadata();
123//! });
124//! ```
125//!
126//! [Tower]: https://docs.rs/tower
127//! [Fetch Metadata]: https://developer.mozilla.org/en-US/docs/Glossary/Fetch_metadata_request_header
128//! [tower-surf]: https://docs.rs/tower-surf
129//! [Axum]: https://docs.rs/axum
130
131use std::sync::Arc;
132
133use futures::future::{self, Either, Ready};
134use http::StatusCode;
135use policy::Policy;
136use tower::{Layer, Service};
137
138pub use authorizer::*;
139pub use policy::PolicyBuilder;
140pub use reporter::*;
141
142mod authorizer;
143pub mod header;
144mod policy;
145mod reporter;
146
147/// Layer that applies [SecFetch] which validates request against CSRF attacks
148pub struct SecFetchLayer<A = NoopAuthorizer, R = NoopReporter> {
149    enforce: bool,
150    policy: Policy,
151    authorizer: Arc<A>,
152    reporter: Arc<R>,
153}
154
155impl<A, R> Clone for SecFetchLayer<A, R> {
156    fn clone(&self) -> Self {
157        Self {
158            enforce: self.enforce,
159            policy: self.policy,
160            authorizer: self.authorizer.clone(),
161            reporter: self.reporter.clone(),
162        }
163    }
164}
165
166impl Default for SecFetchLayer {
167    fn default() -> Self {
168        Self {
169            enforce: true,
170            policy: Policy::default(),
171            authorizer: Arc::new(NoopAuthorizer),
172            reporter: Arc::new(NoopReporter),
173        }
174    }
175}
176
177impl SecFetchLayer {
178    pub fn new<F>(make_policy: F) -> Self
179    where
180        F: FnOnce(&mut PolicyBuilder),
181    {
182        let mut builder = PolicyBuilder::new();
183        make_policy(&mut builder);
184        let policy = builder.build();
185        Self {
186            policy,
187            ..Default::default()
188        }
189    }
190}
191
192impl<OldA, OldR> SecFetchLayer<OldA, OldR> {
193    pub fn allowing(
194        self,
195        paths: impl Into<Arc<[&'static str]>>,
196    ) -> SecFetchLayer<PathAuthorizer, OldR> {
197        self.with_authorizer(PathAuthorizer::new(paths))
198    }
199
200    pub fn no_enforce(mut self) -> Self {
201        self.enforce = false;
202        self
203    }
204
205    pub fn with_authorizer<A: SecFetchAuthorizer>(self, authorizer: A) -> SecFetchLayer<A, OldR> {
206        SecFetchLayer {
207            enforce: self.enforce,
208            policy: self.policy,
209            authorizer: Arc::from(authorizer),
210            reporter: self.reporter,
211        }
212    }
213
214    pub fn with_reporter<R: SecFetchReporter>(self, reporter: R) -> SecFetchLayer<OldA, R> {
215        SecFetchLayer {
216            enforce: self.enforce,
217            policy: self.policy,
218            authorizer: self.authorizer,
219            reporter: Arc::from(reporter),
220        }
221    }
222}
223
224impl<A, R, S> Layer<S> for SecFetchLayer<A, R> {
225    type Service = SecFetch<A, R, S>;
226
227    fn layer(&self, inner: S) -> Self::Service {
228        SecFetch {
229            enforce: self.enforce,
230            policy: self.policy,
231            authorizer: self.authorizer.clone(),
232            reporter: self.reporter.clone(),
233            inner,
234        }
235    }
236}
237
238/// Middleware protecting against CSRF attacks
239pub struct SecFetch<A, R, S> {
240    enforce: bool,
241    policy: Policy,
242    authorizer: Arc<A>,
243    reporter: Arc<R>,
244    inner: S,
245}
246
247impl<A, R, S> Clone for SecFetch<A, R, S>
248where
249    S: Clone,
250{
251    fn clone(&self) -> Self {
252        Self {
253            enforce: self.enforce,
254            policy: self.policy,
255            authorizer: self.authorizer.clone(),
256            reporter: self.reporter.clone(),
257            inner: self.inner.clone(),
258        }
259    }
260}
261
262impl<A, R, ReqB, ResB, S> Service<http::Request<ReqB>> for SecFetch<A, R, S>
263where
264    A: SecFetchAuthorizer,
265    R: SecFetchReporter,
266    S: Service<http::Request<ReqB>, Response = http::Response<ResB>>,
267    ResB: Default,
268{
269    type Response = S::Response;
270
271    type Error = S::Error;
272
273    type Future = Either<S::Future, Ready<Result<Self::Response, Self::Error>>>;
274
275    #[inline]
276    fn poll_ready(
277        &mut self,
278        cx: &mut std::task::Context<'_>,
279    ) -> std::task::Poll<Result<(), Self::Error>> {
280        self.inner.poll_ready(cx)
281    }
282
283    fn call(&mut self, request: http::Request<ReqB>) -> Self::Future {
284        let mut allow = |request| Either::Left(self.inner.call(request));
285        let deny = || {
286            Either::Right(future::ready(Ok(http::Response::builder()
287                .status(StatusCode::FORBIDDEN)
288                .body(ResB::default())
289                .expect("valid response"))))
290        };
291
292        match self.authorizer.authorize(&request) {
293            AuthorizationDecision::Allowed => return allow(request),
294            AuthorizationDecision::Denied => return deny(),
295            AuthorizationDecision::Continue => {}
296        }
297
298        if self.policy.allow(&request) {
299            return allow(request);
300        }
301
302        self.reporter.on_request_denied(&request);
303
304        // the request was denied, but we are not enforcing it
305        // we report the failure and let the request continue
306        if !self.enforce {
307            return allow(request);
308        }
309
310        deny()
311    }
312}
313
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use assert2::{check, let_assert};
    use http::Method;
    use tower::ServiceExt;
    use tower_test::mock;

    use super::*;

    // Builds a request to `https://example.com` carrying the three Fetch Metadata headers.
    // The method and the path are optional and default to `GET` and `/`.
    macro_rules! request {
        // neither method nor path
        (site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!("/", site => $site, mode => $mode, dest => $dest)
        };

        // path only
        ($path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {
            request!(::http::Method::GET, $path, site => $site, mode => $mode, dest => $dest)
        };

        // fully specified
        ($method:expr, $path:expr, site => $site:expr, mode => $mode:expr, dest => $dest:expr) => {{
            let uri = format!("https://example.com{}", $path);

            ::http::Request::builder()
                .method($method)
                .uri(uri)
                .header(header::SEC_FETCH_SITE, $site)
                .header(header::SEC_FETCH_MODE, $mode)
                .header(header::SEC_FETCH_DEST, $dest)
                .body(())
                .unwrap()
        }};
    }

    // Sends `$req` through a mock service wrapped in `$layer` (the default layer when omitted)
    // and passes the response to `$assert_resp`. The mock answers the first request it
    // receives with an empty default response.
    macro_rules! assert_request {
        ($req:expr, $assert_resp:expr) => {
            assert_request!($req, $assert_resp, SecFetchLayer::default())
        };

        ($req:expr, $assert_resp:expr, $layer:expr) => {
            let (svc, mut handler) =
                mock::spawn_layer::<http::Request<()>, http::Response<()>, _>($layer);

            tokio::spawn(async move {
                let_assert!(Some((_, send)) = handler.next_request().await);
                send.send_response(http::Response::new(()));
            });

            let outcome = svc.into_inner().oneshot($req).await;

            ($assert_resp)(outcome.unwrap());
        };
    }

    /// Checks that the response has a success status.
    fn expect_success(response: http::Response<()>) {
        check!(response.status().is_success());
    }

    /// Checks that the response is a `403 Forbidden`.
    fn expect_forbidden(response: http::Response<()>) {
        check!(response.status() == StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn it_allows_requests_missing_the_fetch_metadata() {
        assert_request!(http::Request::new(()), expect_success);
    }

    #[tokio::test]
    async fn it_rejects_requests_missing_the_fetch_metadata_if_configured() {
        let strict = SecFetchLayer::new(|policy| {
            policy.reject_missing_metadata();
        });

        assert_request!(http::Request::new(()), expect_forbidden, strict);
    }

    #[tokio::test]
    async fn it_allows_same_site_requests() {
        let same_site = request!(site => "same-site", mode => "navigate", dest => "document");

        assert_request!(same_site, expect_success);
    }

    #[tokio::test]
    async fn it_disallows_cross_origin_requests() {
        let cross_site = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(cross_site, expect_forbidden);
    }

    #[tokio::test]
    async fn it_allows_cross_origin_requests_safe_methods_if_configured() {
        let lenient = SecFetchLayer::new(|policy| {
            policy.allow_safe_methods();
        });
        let cross_site =
            request!(Method::GET, "/", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(cross_site, expect_success, lenient);
    }

    #[tokio::test]
    async fn it_allows_navigation_requests() {
        let navigation = request!(site => "cross-site", mode => "navigate", dest => "document");

        assert_request!(navigation, expect_success);
    }

    #[tokio::test]
    async fn it_ignores_explicitely_authorized_requests() {
        let cross_site =
            request!("/allowed", site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            cross_site,
            expect_success,
            SecFetchLayer::default().allowing(["/allowed"])
        );
    }

    #[tokio::test]
    async fn it_allows_denied_requests_if_enforcement_is_turned_off() {
        let cross_site = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            cross_site,
            expect_success,
            SecFetchLayer::default().no_enforce()
        );
    }

    /// Reporter that records whether it has been notified of a denied request.
    #[derive(Default)]
    struct TestReporter {
        called: AtomicBool,
    }

    impl SecFetchReporter for TestReporter {
        fn on_request_denied<B>(&self, _: &http::Request<B>) {
            self.called.store(true, Ordering::SeqCst);
        }
    }

    #[tokio::test]
    async fn it_reports_a_denied_requests() {
        // one handle goes to the layer, the other stays here to inspect the flag afterwards
        let reporter = Arc::new(TestReporter::default());
        let cross_site = request!(site => "cross-site", mode => "cors", dest => "empty");

        assert_request!(
            cross_site,
            expect_forbidden,
            SecFetchLayer::default().with_reporter(reporter.clone())
        );

        let called = reporter.called.load(Ordering::SeqCst);
        check!(
            called,
            "reporter was not called despite the request being rejected"
        );
    }
}
492        );
493    }
494}