Skip to main content

tower_http/csrf/
layer.rs

1use std::fmt::{self, Debug, Formatter};
2use std::sync::Arc;
3
4use http::{Method, Uri};
5use tower_layer::Layer;
6
7use super::service::Csrf;
8use super::url::UriExt;
9use super::{BypassFn, ConfigError, DebugFn, DefaultResponseForProtectionError, Origins};
10
11/// Layer that applies the [`Csrf`] middleware.
12///
13/// See the [module docs](crate::csrf) for an example.
14#[derive(Clone)]
15#[must_use]
16pub struct CsrfLayer<T = DefaultResponseForProtectionError> {
17    insecure_bypass: Option<Arc<BypassFn>>,
18    rejection_response: T,
19    trusted_origins: Origins,
20}
21
22impl Default for CsrfLayer {
23    fn default() -> Self {
24        Self {
25            insecure_bypass: None,
26            rejection_response: DefaultResponseForProtectionError,
27            trusted_origins: Origins::default(),
28        }
29    }
30}
31
32impl<T> Debug for CsrfLayer<T> {
33    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
34        f.debug_struct("CsrfLayer")
35            .field(
36                "insecure_bypass",
37                &self.insecure_bypass.as_ref().map(|_| DebugFn),
38            )
39            .field("trusted_origins", &self.trusted_origins)
40            .field("rejection_response", &DebugFn)
41            .finish()
42    }
43}
44
45impl CsrfLayer {
46    /// Creates a new `CsrfLayer` with no trusted origins, no bypass, and the
47    /// default rejection response.
48    pub fn new() -> Self {
49        Self::default()
50    }
51}
52
53impl<T> CsrfLayer<T> {
54    /// Adds a trusted origin that allows all requests whose `Origin` header
55    /// matches the given value.
56    ///
57    /// The value is matched **byte-for-byte** against the request's `Origin`
58    /// header — there is no normalization (this mirrors the Go reference). It
59    /// must therefore be written exactly as a browser sends it:
60    ///
61    /// - form `scheme://host[:port]`, where `scheme` is `http` or `https`;
62    /// - the host lowercased (browsers lowercase it; IDN hosts must be given in
63    ///   punycode, e.g. `xn--exmple-cua.com`);
64    /// - **default ports omitted** — browsers drop `:80`/`:443`, so an explicit
65    ///   default port (e.g. `https://example.com:443`) will never match;
66    /// - **no trailing slash**, path, query, or fragment.
67    ///
68    /// Inputs that can't represent a browser `Origin` are rejected with a
69    /// [`ConfigError`]; inputs that parse but aren't in the canonical browser
70    /// form above are accepted but will silently never match.
71    ///
72    /// ```
73    /// # use tower_http::csrf::CsrfLayer;
74    /// // Matches `Origin: https://example.com`:
75    /// let layer = CsrfLayer::new().add_trusted_origin("https://example.com")?;
76    ///
77    /// // Accepted, but never matches a browser Origin (explicit default port):
78    /// let layer = CsrfLayer::new().add_trusted_origin("https://example.com:443")?;
79    /// # Ok::<_, tower_http::csrf::ConfigError>(())
80    /// ```
81    pub fn add_trusted_origin<S: AsRef<str>>(mut self, origin: S) -> Result<Self, ConfigError> {
82        let origin = origin.as_ref();
83
84        // validate the form; the origin is stored and matched verbatim.
85        Uri::parse_origin(origin)?;
86
87        #[cfg(feature = "tracing")]
88        tracing::debug!(origin = %origin, "added trusted origin");
89
90        self.trusted_origins.insert(origin.to_owned());
91
92        Ok(self)
93    }
94
95    /// Adds a bypass predicate that returns `true` for requests which should
96    /// skip CSRF protection.
97    ///
98    /// This is an escape hatch for endpoints that legitimately need to accept
99    /// cross-origin POSTs (e.g. webhook receivers). Bypassed endpoints must
100    /// have their own protection (signed payloads, authentication tokens,
101    /// etc.) — otherwise they are CSRF-vulnerable.
102    pub fn with_insecure_bypass<F>(mut self, predicate: F) -> Self
103    where
104        F: Fn(&Method, &Uri) -> bool + Send + Sync + 'static,
105    {
106        #[cfg(feature = "tracing")]
107        tracing::debug!("added insecure bypass");
108
109        self.insecure_bypass = Some(Arc::new(predicate));
110        self
111    }
112
113    /// Replaces the response builder used when a request is rejected.
114    ///
115    /// Accepts any type that implements [`ResponseForProtectionError`](super::ResponseForProtectionError),
116    /// including a `FnMut(ProtectionError) -> Response<B> + Clone` closure.
117    /// The default builder returns a `403 Forbidden` with an empty body.
118    /// Regardless of the builder, [`Csrf`](super::Csrf) attaches the
119    /// [`ProtectionError`](super::ProtectionError) to the response's extensions,
120    /// so a custom builder need not re-attach it.
121    pub fn with_rejection_response<R>(self, rejection_response: R) -> CsrfLayer<R>
122    where
123        R: Clone,
124    {
125        CsrfLayer {
126            insecure_bypass: self.insecure_bypass,
127            trusted_origins: self.trusted_origins,
128            rejection_response,
129        }
130    }
131}
132
133impl<S, T> Layer<S> for CsrfLayer<T>
134where
135    T: Clone,
136{
137    type Service = Csrf<S, T>;
138
139    fn layer(&self, inner: S) -> Self::Service {
140        Csrf::new(
141            inner,
142            self.insecure_bypass.clone(),
143            self.rejection_response.clone(),
144            self.trusted_origins.clone(),
145        )
146    }
147}