tower_sec_fetch/
policy.rs

1use http::{HeaderValue, Method};
2
3use crate::header;
4
/// A resolved Resource Isolation Policy, evaluated per request by `Policy::allow`.
///
/// Produced by [`PolicyBuilder`]; the `Default` value lets requests without
/// fetch metadata through and does not exempt safe methods.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct Policy {
    /// Deny requests missing any of `Sec-Fetch-Site`, `Sec-Fetch-Mode` or `Sec-Fetch-Dest`.
    reject_missing_metadata: bool,
    /// Allow `GET`, `HEAD` and `OPTIONS` requests without inspecting fetch metadata.
    allow_safe_methods: bool,
}
10
11impl Policy {
12    // Resource Isolation Policy
13    // Implemented following https://web.dev/articles/fetch-metadata
14    pub fn allow<B>(&self, request: &http::Request<B>) -> bool {
15        if self.allow_safe_methods
16            && method_in(
17                request.method(),
18                [Method::GET, Method::HEAD, Method::OPTIONS],
19            )
20        {
21            #[cfg(feature = "tracing")]
22            tracing::trace!(
23                method = %request.method(),
24                path = request.uri().path(),
25                "request uses a safe method: allowed",
26            );
27
28            return true;
29        }
30
31        let sec_fetch_site = request.headers().get(header::SEC_FETCH_SITE);
32        let sec_fetch_mode = request.headers().get(header::SEC_FETCH_MODE);
33        let sec_fetch_dest = request.headers().get(header::SEC_FETCH_DEST);
34
35        let sec_fetch = zip3(sec_fetch_site, sec_fetch_mode, sec_fetch_dest);
36
37        let Some((sec_fetch_site, sec_fetch_mode, sec_fetch_dest)) = sec_fetch else {
38            #[cfg(feature = "tracing")]
39            tracing::trace!(
40                method = %request.method(),
41                path = request.uri().path(),
42                "request is missing fetch metadata: {}",
43                if self.reject_missing_metadata { "denied" } else { "allowed" },
44            );
45
46            // Fetch metadata headers are missing.
47            // Either the request doesn't come from a browser, or the browser is too old.
48            return !self.reject_missing_metadata;
49        };
50
51        if header_in(sec_fetch_site, ["same-origin", "same-site", "none"]) {
52            #[cfg(feature = "tracing")]
53            tracing::trace!(
54                method = %request.method(),
55                path = request.uri().path(),
56                "request is same-site or user initiated: allowed",
57            );
58
59            // request is same-site or user initiated
60            return true;
61        }
62
63        if sec_fetch_mode == "navigate"
64            && request.method() == Method::GET
65            && header_in(sec_fetch_dest, ["empty", "document"])
66        {
67            #[cfg(feature = "tracing")]
68            tracing::trace!(
69                method = %request.method(),
70                path = request.uri().path(),
71                "request is a non-embed navigation: allowed",
72            );
73
74            // request is a regular navigation event and is not being embedded
75            return true;
76        }
77
78        #[cfg(feature = "tracing")]
79        tracing::trace!(
80            method = %request.method(),
81            path = request.uri().path(),
82            "request denied",
83        );
84
85        // request is denied
86        false
87    }
88}
89
/// Allows customizing the behaviour of the default evaluation policy.
///
/// Every option starts disabled; each builder method switches one of them on.
#[derive(Clone, Debug)]
pub struct PolicyBuilder {
    /// Set by `reject_missing_metadata()`; copied into the built policy.
    reject_missing_metadata: bool,
    /// Set by `allow_safe_methods()`; copied into the built policy.
    allow_safe_methods: bool,
}
95
96impl PolicyBuilder {
97    pub(crate) fn new() -> Self {
98        Self {
99            reject_missing_metadata: false,
100            allow_safe_methods: false,
101        }
102    }
103
104    /// Reject requests that do not provide all three Fetch Metadata headers:
105    /// `sec-fetch-site`, `sec-fetch-mode`, `sec-fetch-dest`
106    pub fn reject_missing_metadata(&mut self) -> &mut Self {
107        self.reject_missing_metadata = true;
108        self
109    }
110
111    /// Allow safe requests (`GET`, `HEAD`, and `OPTIONS`) regardless of their origin
112    pub fn allow_safe_methods(&mut self) -> &mut Self {
113        self.allow_safe_methods = true;
114        self
115    }
116
117    pub(crate) fn build(self) -> Policy {
118        Policy {
119            reject_missing_metadata: self.reject_missing_metadata,
120            allow_safe_methods: self.allow_safe_methods,
121        }
122    }
123}
124
/// Merges three options into one tuple, which is `Some` only if all three inputs are `Some`.
fn zip3<T1, T2, T3>(a: Option<T1>, b: Option<T2>, c: Option<T3>) -> Option<(T1, T2, T3)> {
    Some((a?, b?, c?))
}
131
132fn header_in(header: &HeaderValue, values: impl IntoIterator<Item = &'static str>) -> bool {
133    values
134        .into_iter()
135        .map(HeaderValue::from_static)
136        .any(|value| value == header)
137}
138
139fn method_in(method: &Method, values: impl IntoIterator<Item = Method>) -> bool {
140    values.into_iter().any(|value| value == method)
141}