Skip to main content

tower_http/csrf/
mod.rs

1//! Modern protection against [cross-site request forgery] (CSRF) attacks.
2//!
3//! This middleware implements the CSRF protection scheme [introduced in Go 1.25][go]
4//! and described in [Filippo Valsorda's blog post][filippo]. It relies on the
5//! [`Sec-Fetch-Site`] and [`Origin`] request headers and requires no
6//! per-request token state.
7//!
8//! Requests are allowed if any of the following hold:
9//!
10//! 1. The method is `GET`, `HEAD`, or `OPTIONS`.
11//! 2. The `Origin` header byte-for-byte matches an allow-listed trusted origin.
12//! 3. `Sec-Fetch-Site` is `same-origin` or `none`.
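//! 4. Neither `Sec-Fetch-Site` nor `Origin` is present.
//! 5. `Sec-Fetch-Site` is absent and the `Origin`'s authority (host and any port)
//!    matches the request's effective host byte-for-byte (the request-target
//!    authority if present, else `Host`). An `Origin` whose scheme is not `http`
//!    or `https` never matches.
//! 6. The request matches the predicate registered with
//!    [`CsrfLayer::with_insecure_bypass`](CsrfLayer::with_insecure_bypass).
//!
//! For these checks, a header whose value is empty is treated as absent.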
16//!
//! By default, rejected requests receive a `403 Forbidden` response; use
//! [`CsrfLayer::with_rejection_response`](CsrfLayer::with_rejection_response)
//! to build a custom rejection response instead. On every rejection — including
//! those answered by a custom builder — the originating [`ProtectionError`] is
//! attached to the response's extensions, so surrounding layers can distinguish
//! explicit cross-origin rejections from conservative fallback rejections (e.g.
//! requests from old browsers without `Sec-Fetch-Site`).
24//!
25//! # Deployment caveat
26//!
//! The middleware trusts whatever `Origin` and `Host` reach it. Reverse proxies
//! and load balancers that rewrite `Host` (e.g. to an internal hostname) or
//! strip `Origin` silently undermine the `Origin`/`Host` fallback. A rewritten
//! `Host` that serves as the effective host no longer matches legitimate
//! same-origin `Origin`s, so requests from clients that omit `Sec-Fetch-Site`
//! are rejected. A stripped `Origin` lets such requests through unchecked.
//! Either way, `Sec-Fetch-Site` becomes the only remaining line of defense.
//! Configure intermediaries to forward both headers unchanged.
33//!
34//! # Example
35//!
36//! ```
37//! use bytes::Bytes;
38//! use http::{Request, Response, StatusCode};
39//! use http_body_util::Full;
40//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
41//! use tower_http::csrf::CsrfLayer;
42//!
43//! async fn handle(_: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
44//!     Ok(Response::new(Full::default()))
45//! }
46//!
47//! # #[tokio::main]
48//! # async fn main() -> Result<(), BoxError> {
49//! let layer = CsrfLayer::new()
50//!     .add_trusted_origin("https://example.com")?;
51//!
52//! let mut service = ServiceBuilder::new()
53//!     .layer(layer)
54//!     .service_fn(handle);
55//!
56//! // Safe methods always pass.
57//! let request = Request::builder()
58//!     .method("GET")
59//!     .uri("/")
60//!     .body(Full::default())
61//!     .unwrap();
62//!
63//! let response = service.ready().await?.call(request).await?;
64//!
65//! assert_eq!(response.status(), StatusCode::OK);
66//!
67//! // Cross-site POSTs are blocked.
68//! let request = Request::builder()
69//!     .method("POST")
70//!     .uri("/")
71//!     .header("host", "example.com")
72//!     .header("sec-fetch-site", "cross-site")
73//!     .body(Full::default())
74//!     .unwrap();
75//!
76//! let response = service.ready().await?.call(request).await?;
77//!
78//! assert_eq!(response.status(), StatusCode::FORBIDDEN);
79//!
80//! # Ok(())
81//! # }
82//! ```
83//!
84//! [cross-site request forgery]: https://developer.mozilla.org/en-US/docs/Glossary/CSRF
85//! [filippo]: https://words.filippo.io/csrf/
86//! [go]: https://pkg.go.dev/net/http#CrossOriginProtection
87//! [`Sec-Fetch-Site`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site
88//! [`Origin`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin
89
90use std::collections::HashSet;
91use std::fmt::{self, Debug, Formatter};
92use std::sync::Arc;
93
94use http::{Method, Uri};
95
96mod future;
97mod layer;
98mod response;
99mod service;
100mod url;
101
102pub use self::future::ResponseFuture;
103pub use self::layer::CsrfLayer;
104pub use self::response::{DefaultResponseForProtectionError, ResponseForProtectionError};
105pub use self::service::Csrf;
106
107/// Errors that can occur while configuring [`CsrfLayer`].
108#[derive(Clone, Debug, PartialEq)]
109#[non_exhaustive]
110pub enum ConfigError {
111    /// The origin string could not be parsed as a URI.
112    InvalidOriginUrl {
113        /// The offending origin string.
114        origin: String,
115        /// The parser error message.
116        message: String,
117    },
118
119    /// An origin URL containing a path, query, or fragment was added as a
120    /// trusted origin.
121    InvalidOriginUrlComponents {
122        /// The offending origin string.
123        origin: String,
124    },
125
126    /// An origin with a scheme other than `http` or `https` (e.g. `file://`,
127    /// `mailto:`, or a bare host with no scheme) was added as a trusted
128    /// origin. Such origins can never match a browser-supplied request
129    /// `Origin`.
130    OpaqueOrigin {
131        /// The offending origin string.
132        origin: String,
133    },
134
135    /// A trusted origin contained non-ASCII characters. Browsers send IDN
136    /// hostnames in punycode form, so the configured value must use the
137    /// punycode form (e.g. `xn--exmple-cua.com`) to ever match.
138    NonAsciiHostname {
139        /// The offending origin string.
140        origin: String,
141    },
142}
143
144impl fmt::Display for ConfigError {
145    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
146        match self {
147            ConfigError::InvalidOriginUrl { origin, message } => {
148                write!(f, "invalid origin {origin:?}: {message}")
149            }
150            ConfigError::InvalidOriginUrlComponents { origin } => write!(
151                f,
152                "invalid origin {origin:?}: path, query, and fragment are not allowed"
153            ),
154            ConfigError::OpaqueOrigin { origin } => write!(
155                f,
156                "invalid origin {origin:?}: scheme must be http or https"
157            ),
158            ConfigError::NonAsciiHostname { origin } => write!(
159                f,
160                "invalid origin {origin:?}: non-ASCII hostnames must be supplied in punycode (xn--…)"
161            ),
162        }
163    }
164}
165
166impl std::error::Error for ConfigError {}
167
168/// Reason a request was rejected by [`Csrf`].
169///
170/// Retrieve the category with [`ProtectionError::kind`]. [`Csrf`] attaches it to
171/// every `403 Forbidden` rejection response's extensions so surrounding layers
172/// can distinguish explicit cross-origin rejections from conservative fallback
173/// rejections.
174///
175/// This is an opaque struct rather than an enum so future variants can carry
176/// additional context without a breaking change; match on [`kind`] instead.
177///
178/// [`kind`]: ProtectionError::kind
179#[derive(Clone, Debug)]
180pub struct ProtectionError {
181    kind: ProtectionErrorKind,
182}
183
184impl ProtectionError {
185    pub(crate) fn new(kind: ProtectionErrorKind) -> Self {
186        Self { kind }
187    }
188
189    /// The category of rejection.
190    pub fn kind(&self) -> ProtectionErrorKind {
191        self.kind
192    }
193}
194
195impl fmt::Display for ProtectionError {
196    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
197        match self.kind {
198            ProtectionErrorKind::CrossOriginRequest => f.write_str("Cross-Origin request detected"),
199            ProtectionErrorKind::CrossOriginRequestFromOldBrowser => {
200                f.write_str("Cross-Origin request from old browser detected")
201            }
202        }
203    }
204}
205
206impl std::error::Error for ProtectionError {}
207
208/// The category of a [`ProtectionError`].
209#[derive(Clone, Copy, Debug, PartialEq, Eq)]
210#[non_exhaustive]
211pub enum ProtectionErrorKind {
212    /// A cross-origin request was detected via `Sec-Fetch-Site`.
213    CrossOriginRequest,
214
215    /// A request without `Sec-Fetch-Site` failed the `Origin`/`Host` fallback
216    /// check. Modern browsers always send `Sec-Fetch-Site`, so this typically
217    /// means the request came from an old browser or non-browser client.
218    CrossOriginRequestFromOldBrowser,
219}
220
221type BypassFn = dyn Fn(&Method, &Uri) -> bool + Send + Sync + 'static;
222
223struct DebugFn;
224
225impl Debug for DebugFn {
226    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
227        f.write_str("<fn>")
228    }
229}
230
231#[derive(Clone, Default)]
232struct Origins(Arc<HashSet<Vec<u8>>>);
233
234impl Origins {
235    fn contains(&self, origin: &[u8]) -> bool {
236        self.0.contains(origin)
237    }
238
239    fn insert(&mut self, origin: impl Into<Vec<u8>>) {
240        Arc::make_mut(&mut self.0).insert(origin.into());
241    }
242}
243
244impl Debug for Origins {
245    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
246        // render trusted origins as utf-h strings
247        write!(f, "Origins(")?;
248        f.debug_set()
249            .entries(self.0.iter().map(|o| String::from_utf8_lossy(o)))
250            .finish()?;
251        write!(f, ")")
252    }
253}
254
255#[cfg(test)]
256mod tests {
257    use std::convert::Infallible;
258
259    use http::{Request, Response, StatusCode};
260    use tower::{service_fn, ServiceExt};
261    use tower_layer::Layer;
262
263    use super::*;
264    use crate::test_helpers::{to_bytes, Body};
265
266    impl PartialEq for super::ProtectionError {
267        fn eq(&self, other: &Self) -> bool {
268            self.kind == other.kind
269        }
270    }
271
272    fn echo_service() -> impl tower::Service<
273        Request<Body>,
274        Response = Response<Body>,
275        Error = Infallible,
276        Future = impl std::future::Future<Output = Result<Response<Body>, Infallible>>,
277    > + Clone {
278        service_fn(|req: Request<Body>| async move {
279            let body: Body = match req.uri().path() {
280                "/foo" => "foo".into(),
281                "/bar" => "bar".into(),
282                _ => Body::empty(),
283            };
284            Ok::<_, Infallible>(Response::new(body))
285        })
286    }
287
288    #[tokio::test]
289    async fn test_service_allows_safe_method() {
290        let svc = CsrfLayer::new()
291            .add_trusted_origin("https://example.com")
292            .unwrap()
293            .layer(echo_service());
294
295        let req = Request::builder()
296            .method("GET")
297            .uri("/foo")
298            .body(Body::empty())
299            .unwrap();
300
301        let res = svc.oneshot(req).await.unwrap();
302
303        assert_eq!(res.status(), StatusCode::OK);
304
305        let body = to_bytes(res.into_body()).await.unwrap();
306        assert_eq!(&body[..], b"foo");
307    }
308
309    #[tokio::test]
310    async fn test_service_allows_post_from_trusted_origin() {
311        let svc = CsrfLayer::new()
312            .add_trusted_origin("https://example.com")
313            .unwrap()
314            .layer(echo_service());
315
316        let req = Request::builder()
317            .method("POST")
318            .uri("/bar")
319            .header("origin", "https://example.com")
320            .body(Body::empty())
321            .unwrap();
322
323        let res = svc.oneshot(req).await.unwrap();
324
325        assert_eq!(res.status(), StatusCode::OK);
326
327        let body = to_bytes(res.into_body()).await.unwrap();
328        assert_eq!(&body[..], b"bar");
329    }
330
331    #[tokio::test]
332    async fn test_service_rejects_post_from_untrusted_origin() {
333        let svc = CsrfLayer::new()
334            .add_trusted_origin("https://example.com")
335            .unwrap()
336            .layer(echo_service());
337
338        let req = Request::builder()
339            .method("POST")
340            .uri("/bar")
341            .header("origin", "https://malicious.example")
342            .body(Body::empty())
343            .unwrap();
344
345        let res = svc.oneshot(req).await.unwrap();
346
347        assert_eq!(res.status(), StatusCode::FORBIDDEN);
348        assert_eq!(
349            res.extensions().get::<ProtectionError>(),
350            Some(&ProtectionError::new(
351                ProtectionErrorKind::CrossOriginRequestFromOldBrowser
352            )),
353        );
354    }
355
356    #[tokio::test]
357    async fn test_service_uses_custom_rejection_response() {
358        let svc = CsrfLayer::new()
359            .with_rejection_response(|_err: ProtectionError| {
360                let mut res = Response::new(Body::from("denied"));
361                *res.status_mut() = StatusCode::IM_A_TEAPOT;
362                res
363            })
364            .layer(echo_service());
365
366        let req = Request::builder()
367            .method("POST")
368            .uri("/bar")
369            .header("origin", "https://malicious.example")
370            .body(Body::empty())
371            .unwrap();
372
373        let res = svc.oneshot(req).await.unwrap();
374
375        assert_eq!(res.status(), StatusCode::IM_A_TEAPOT);
376        assert_ne!(res.status(), StatusCode::OK);
377        // The middleware attaches the error even though a custom builder
378        // produced the response.
379        assert_eq!(
380            res.extensions().get::<ProtectionError>(),
381            Some(&ProtectionError::new(
382                ProtectionErrorKind::CrossOriginRequestFromOldBrowser
383            )),
384        );
385
386        let body = to_bytes(res.into_body()).await.unwrap();
387        assert_eq!(&body[..], b"denied");
388    }
389
390    #[tokio::test]
391    async fn test_service_custom_rejection_response_not_invoked_when_allowed() {
392        let svc = CsrfLayer::new()
393            .add_trusted_origin("https://example.com")
394            .unwrap()
395            .with_rejection_response(|_err: ProtectionError| {
396                let mut res = Response::new(Body::from("denied"));
397                *res.status_mut() = StatusCode::IM_A_TEAPOT;
398                res
399            })
400            .layer(echo_service());
401
402        let req = Request::builder()
403            .method("POST")
404            .uri("/bar")
405            .header("origin", "https://example.com")
406            .body(Body::empty())
407            .unwrap();
408
409        let res = svc.oneshot(req).await.unwrap();
410
411        assert_eq!(res.status(), StatusCode::OK);
412        assert_ne!(res.status(), StatusCode::IM_A_TEAPOT);
413        assert!(res.extensions().get::<ProtectionError>().is_none());
414
415        let body = to_bytes(res.into_body()).await.unwrap();
416        assert_eq!(&body[..], b"bar");
417    }
418
419    #[test]
420    fn test_layer_add_trusted_origin() {
421        // Smoke check that the layer threads parse_origin's Ok and Err
422        // through; the full validation matrix lives in url.rs.
423        assert!(CsrfLayer::new()
424            .add_trusted_origin("https://example.com")
425            .is_ok());
426        assert!(matches!(
427            CsrfLayer::new().add_trusted_origin("not a valid url"),
428            Err(ConfigError::InvalidOriginUrl { .. })
429        ));
430    }
431
432    #[test]
433    fn test_middleware_bypass() {
434        let layer = CsrfLayer::new()
435            .with_insecure_bypass(|_method, uri| -> bool { uri.path() == "/bypass" });
436
437        let middleware = layer.layer(());
438
439        struct Test {
440            name: &'static str,
441            path: &'static str,
442            sec_fetch_site: Option<&'static str>,
443            result: Result<(), ProtectionError>,
444        }
445
446        let tests = [
447            Test {
448                name: "bypass path without sec-fetch-site",
449                path: "/bypass",
450                sec_fetch_site: None,
451                result: Ok(()),
452            },
453            Test {
454                name: "bypass path with cross-site",
455                path: "/bypass",
456                sec_fetch_site: Some("cross-site"),
457                result: Ok(()),
458            },
459            Test {
460                name: "non-bypass path without sec-fetch-site",
461                path: "/api",
462                sec_fetch_site: None,
463                result: Err(ProtectionError::new(
464                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
465                )),
466            },
467            Test {
468                name: "non-bypass path with cross-site",
469                path: "/api",
470                sec_fetch_site: Some("cross-site"),
471                result: Err(ProtectionError::new(
472                    ProtectionErrorKind::CrossOriginRequest,
473                )),
474            },
475        ];
476
477        for test in tests {
478            let mut req = Request::builder()
479                .method("POST")
480                .header("host", "example.com")
481                .header("origin", "https://attacker.example")
482                .uri(format!("https://example.com{}", test.path));
483
484            if let Some(sec_fetch_site) = test.sec_fetch_site {
485                req = req.header("sec-fetch-site", sec_fetch_site);
486            }
487
488            let req = req.body(()).unwrap();
489
490            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
491        }
492    }
493
494    #[test]
495    fn test_middleware_bypass_applies_when_origin_unparseable() {
496        let middleware = CsrfLayer::new()
497            .with_insecure_bypass(|_method, uri| uri.path() == "/bypass")
498            .layer(());
499
500        let req = Request::builder()
501            .method("POST")
502            .uri("https://example.com/bypass")
503            .header("host", "example.com")
504            .header(
505                "origin",
506                http::HeaderValue::from_bytes(&[0xFF, 0xFE]).unwrap(),
507            )
508            .body(())
509            .unwrap();
510
511        assert_eq!(middleware.verify(&req), Ok(()));
512    }
513
514    #[test]
515    fn test_middleware_debug_trait() {
516        let layer = CsrfLayer::new();
517
518        let middleware = layer
519            .clone()
520            .with_insecure_bypass(|method, uri| method == Method::POST && uri.path() == "/bypass")
521            .layer(());
522
523        assert_eq!(
524            format!("{:?}", middleware),
525            "Csrf { inner: (), insecure_bypass: Some(<fn>), trusted_origins: Origins({}), rejection_response: <fn> }"
526        );
527
528        let middleware = layer.layer(());
529
530        assert_eq!(
531            format!("{:?}", middleware),
532            "Csrf { inner: (), insecure_bypass: None, trusted_origins: Origins({}), rejection_response: <fn> }"
533        );
534    }
535
536    #[test]
537    fn test_middleware_origin_host_port_match() {
538        let middleware: Csrf<()> = Default::default();
539
540        struct Test {
541            name: &'static str,
542            uri: &'static str,
543            host: Option<&'static str>,
544            origin: &'static str,
545            result: Result<(), ProtectionError>,
546        }
547
548        let tests = [
549            Test {
550                name: "default port both sides",
551                uri: "/",
552                host: Some("example.com"),
553                origin: "https://example.com",
554                result: Ok(()),
555            },
556            Test {
557                name: "same non-default port both sides",
558                uri: "/",
559                host: Some("example.com:8443"),
560                origin: "https://example.com:8443",
561                result: Ok(()),
562            },
563            Test {
564                name: "explicit default port both sides",
565                uri: "/",
566                host: Some("example.com:443"),
567                origin: "https://example.com:443",
568                result: Ok(()),
569            },
570            Test {
571                name: "mismatched non-default ports",
572                uri: "/",
573                host: Some("example.com:8443"),
574                origin: "https://example.com:8444",
575                result: Err(ProtectionError::new(
576                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
577                )),
578            },
579            Test {
580                // Strict byte match: an explicit default port does not equal an
581                // implicit one (the reference does not normalize ports).
582                name: "origin has explicit default, host implicit",
583                uri: "/",
584                host: Some("example.com"),
585                origin: "https://example.com:443",
586                result: Err(ProtectionError::new(
587                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
588                )),
589            },
590            Test {
591                name: "host has explicit default, origin implicit",
592                uri: "/",
593                host: Some("example.com:443"),
594                origin: "https://example.com",
595                result: Err(ProtectionError::new(
596                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
597                )),
598            },
599            Test {
600                name: "host implicit, origin explicit non-default",
601                uri: "/",
602                host: Some("example.com"),
603                origin: "https://example.com:8443",
604                result: Err(ProtectionError::new(
605                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
606                )),
607            },
608            Test {
609                name: "missing host, uri authority implicit, origin explicit non-default",
610                uri: "https://example.com/path",
611                host: None,
612                origin: "https://example.com:8443",
613                result: Err(ProtectionError::new(
614                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
615                )),
616            },
617            Test {
618                // No request-target authority, so the Host header is the effective
619                // host, compared verbatim — a malformed Host never matches an Origin.
620                name: "malformed host header compared verbatim",
621                uri: "/path",
622                host: Some("not a valid authority"),
623                origin: "https://example.com",
624                result: Err(ProtectionError::new(
625                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
626                )),
627            },
628            Test {
629                // RFC 7230 §5.3 / Go parity: the request-target authority is the
630                // effective host (Host header ignored); here it matches Origin.
631                name: "request-target authority wins over host header (match)",
632                uri: "https://example.com/path",
633                host: Some("other.example"),
634                origin: "https://example.com",
635                result: Ok(()),
636            },
637            Test {
638                // Security-relevant: Origin matches the Host header but not the
639                // winning request-target authority, so it stays cross-origin.
640                name: "origin matching host header but not authority is rejected",
641                uri: "https://example.com/path",
642                host: Some("other.example"),
643                origin: "https://other.example",
644                result: Err(ProtectionError::new(
645                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
646                )),
647            },
648            Test {
649                name: "missing host, uri carries authority (match)",
650                uri: "https://example.com/path",
651                host: None,
652                origin: "https://example.com",
653                result: Ok(()),
654            },
655            Test {
656                name: "missing host, uri authority mismatch",
657                uri: "https://other.example/path",
658                host: None,
659                origin: "https://example.com",
660                result: Err(ProtectionError::new(
661                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
662                )),
663            },
664            Test {
665                name: "missing host and no uri authority",
666                uri: "/path",
667                host: None,
668                origin: "https://example.com",
669                result: Err(ProtectionError::new(
670                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
671                )),
672            },
673            Test {
674                name: "scheme-less origin does not match host even if bytes agree",
675                uri: "/",
676                host: Some("example.com:8443"),
677                origin: "example.com:8443",
678                result: Err(ProtectionError::new(
679                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
680                )),
681            },
682            Test {
683                name: "non-http origin scheme does not enter host fallback",
684                uri: "/",
685                host: Some("example.com:8443"),
686                origin: "ftp://example.com:8443",
687                result: Err(ProtectionError::new(
688                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
689                )),
690            },
691        ];
692
693        for test in tests {
694            let mut req = Request::builder().method(Method::POST).uri(test.uri);
695
696            if let Some(host) = test.host {
697                req = req.header("host", host);
698            }
699
700            let req = req.header("origin", test.origin).body(()).unwrap();
701
702            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
703        }
704    }
705
706    #[test]
707    fn test_middleware_sec_fetch_site() {
708        let middleware: Csrf<()> = Default::default();
709
710        const NON_DECODABLE: &[u8] = &[0xFF, 0xFE];
711        assert!(
712            http::HeaderValue::from_bytes(NON_DECODABLE)
713                .expect("NON_DECODABLE must be a valid HeaderValue")
714                .to_str()
715                .is_err(),
716            "NON_DECODABLE must fail HeaderValue::to_str()"
717        );
718
719        struct Test {
720            name: &'static str,
721            method: http::Method,
722            sec_fetch_site: Option<&'static [u8]>,
723            origin: Option<&'static [u8]>,
724            result: Result<(), ProtectionError>,
725        }
726
727        let tests = [
728            Test {
729                name: "same-origin allowed",
730                method: Method::GET,
731                sec_fetch_site: Some(b"same-origin"),
732                origin: None,
733                result: Ok(()),
734            },
735            Test {
736                name: "none allowed",
737                method: Method::POST,
738                sec_fetch_site: Some(b"none"),
739                origin: None,
740                result: Ok(()),
741            },
742            Test {
743                name: "cross-site blocked",
744                method: Method::POST,
745                sec_fetch_site: Some(b"cross-site"),
746                origin: None,
747                result: Err(ProtectionError::new(
748                    ProtectionErrorKind::CrossOriginRequest,
749                )),
750            },
751            Test {
752                name: "same-site blocked",
753                method: Method::POST,
754                sec_fetch_site: Some(b"same-site"),
755                origin: None,
756                result: Err(ProtectionError::new(
757                    ProtectionErrorKind::CrossOriginRequest,
758                )),
759            },
760            Test {
761                name: "no header with no origin",
762                method: Method::POST,
763                sec_fetch_site: None,
764                origin: None,
765                result: Ok(()),
766            },
767            Test {
768                name: "no header with matching origin",
769                method: Method::POST,
770                sec_fetch_site: None,
771                origin: Some(b"https://example.com"),
772                result: Ok(()),
773            },
774            Test {
775                name: "no header with mismatched origin",
776                method: Method::POST,
777                sec_fetch_site: None,
778                origin: Some(b"https://attacker.example"),
779                result: Err(ProtectionError::new(
780                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
781                )),
782            },
783            Test {
784                name: "no header with null origin",
785                method: Method::POST,
786                sec_fetch_site: None,
787                origin: Some(b"null"),
788                result: Err(ProtectionError::new(
789                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
790                )),
791            },
792            Test {
793                name: "GET allowed",
794                method: Method::GET,
795                sec_fetch_site: Some(b"cross-site"),
796                origin: None,
797                result: Ok(()),
798            },
799            Test {
800                name: "HEAD allowed",
801                method: Method::HEAD,
802                sec_fetch_site: Some(b"cross-site"),
803                origin: None,
804                result: Ok(()),
805            },
806            Test {
807                name: "OPTIONS allowed",
808                method: Method::OPTIONS,
809                sec_fetch_site: Some(b"cross-site"),
810                origin: None,
811                result: Ok(()),
812            },
813            Test {
814                name: "PUT blocked",
815                method: Method::PUT,
816                sec_fetch_site: Some(b"cross-site"),
817                origin: None,
818                result: Err(ProtectionError::new(
819                    ProtectionErrorKind::CrossOriginRequest,
820                )),
821            },
822            Test {
823                name: "non-decodable origin without sec-fetch-site rejected",
824                method: Method::POST,
825                sec_fetch_site: None,
826                origin: Some(NON_DECODABLE),
827                result: Err(ProtectionError::new(
828                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
829                )),
830            },
831            Test {
832                name: "non-decodable sec-fetch-site without origin rejected",
833                method: Method::POST,
834                sec_fetch_site: Some(NON_DECODABLE),
835                origin: None,
836                result: Err(ProtectionError::new(
837                    ProtectionErrorKind::CrossOriginRequest,
838                )),
839            },
840            Test {
841                name: "empty sec-fetch-site without origin allowed",
842                method: Method::POST,
843                sec_fetch_site: Some(b""),
844                origin: None,
845                result: Ok(()),
846            },
847            Test {
848                name: "empty origin without sec-fetch-site allowed",
849                method: Method::POST,
850                sec_fetch_site: None,
851                origin: Some(b""),
852                result: Ok(()),
853            },
854        ];
855
856        for test in tests {
857            let mut req = Request::builder()
858                .method(test.method)
859                .header("host", "example.com");
860
861            if let Some(sec_fetch_site) = test.sec_fetch_site {
862                req = req.header("sec-fetch-site", sec_fetch_site);
863            }
864
865            if let Some(origin) = test.origin {
866                req = req.header("origin", origin);
867            }
868
869            let req = req.body(()).unwrap();
870
871            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
872        }
873    }
874
875    #[test]
876    fn test_middleware_trusted_origin_bypass() {
877        let layer = CsrfLayer::new()
878            .add_trusted_origin("https://trusted.example")
879            .unwrap();
880
881        let middleware = layer.layer(());
882
883        struct Test {
884            name: &'static str,
885            sec_fetch_site: Option<&'static str>,
886            origin: Option<&'static str>,
887            result: Result<(), ProtectionError>,
888        }
889
890        let tests = [
891            Test {
892                name: "trusted origin without sec-fetch-site",
893                origin: Some("https://trusted.example"),
894                sec_fetch_site: None,
895                result: Ok(()),
896            },
897            Test {
898                name: "trusted origin with cross-site",
899                origin: Some("https://trusted.example"),
900                sec_fetch_site: Some("cross-site"),
901                result: Ok(()),
902            },
903            Test {
904                name: "untrusted origin without sec-fetch-site",
905                origin: Some("https://attacker.example"),
906                sec_fetch_site: None,
907                result: Err(ProtectionError::new(
908                    ProtectionErrorKind::CrossOriginRequestFromOldBrowser,
909                )),
910            },
911            Test {
912                name: "untrusted origin with cross-site",
913                origin: Some("https://attacker.example"),
914                sec_fetch_site: Some("cross-site"),
915                result: Err(ProtectionError::new(
916                    ProtectionErrorKind::CrossOriginRequest,
917                )),
918            },
919        ];
920
921        for test in tests {
922            let mut req = Request::builder()
923                .method("POST")
924                .header("host", "example.com");
925
926            if let Some(sec_fetch_site) = test.sec_fetch_site {
927                req = req.header("sec-fetch-site", sec_fetch_site);
928            }
929
930            if let Some(origin) = test.origin {
931                req = req.header("origin", origin);
932            }
933
934            let req = req.body(()).unwrap();
935
936            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
937        }
938    }
939
940    #[test]
941    fn test_middleware_trusted_origin_strict_byte_match() {
942        // Trusted origins are matched byte-for-byte against the request's Origin
943        // header (no canonicalization), mirroring the Go reference. Only an exact
944        // match is trusted; case- and port-form variants are not.
945        struct Test {
946            name: &'static str,
947            trusted: &'static str,
948            origin: &'static str,
949            result: Result<(), ProtectionError>,
950        }
951
952        let tests = [
953            Test {
954                name: "exact match trusted",
955                trusted: "https://example.com",
956                origin: "https://example.com",
957                result: Ok(()),
958            },
959            Test {
960                name: "exact match with non-default port",
961                trusted: "https://example.com:8443",
962                origin: "https://example.com:8443",
963                result: Ok(()),
964            },
965            Test {
966                name: "host case mismatch not trusted",
967                trusted: "https://Example.COM",
968                origin: "https://example.com",
969                result: Err(ProtectionError::new(
970                    ProtectionErrorKind::CrossOriginRequest,
971                )),
972            },
973            Test {
974                name: "explicit default port not trusted against bare origin",
975                trusted: "https://example.com:443",
976                origin: "https://example.com",
977                result: Err(ProtectionError::new(
978                    ProtectionErrorKind::CrossOriginRequest,
979                )),
980            },
981            Test {
982                name: "bare trusted not matched by explicit-default-port origin",
983                trusted: "https://example.com",
984                origin: "https://example.com:443",
985                result: Err(ProtectionError::new(
986                    ProtectionErrorKind::CrossOriginRequest,
987                )),
988            },
989        ];
990
991        for test in tests {
992            let middleware = CsrfLayer::new()
993                .add_trusted_origin(test.trusted)
994                .unwrap_or_else(|e| panic!("{}: add_trusted_origin failed: {e}", test.name))
995                .layer(());
996
997            let req = Request::builder()
998                .method("POST")
999                .header("host", "other.example")
1000                .header("origin", test.origin)
1001                .header("sec-fetch-site", "cross-site")
1002                .body(())
1003                .unwrap();
1004
1005            assert_eq!(middleware.verify(&req), test.result, "{}", test.name);
1006        }
1007    }
1008}