tower_sec_fetch/
policy.rs1use http::{HeaderValue, Method};
2
3use crate::header;
4
/// Fetch Metadata request-filtering policy, usually configured through
/// [`PolicyBuilder`].
///
/// Both options are `false` in the `Default` value.
#[derive(Copy, Clone, Debug, Default)]
pub struct Policy {
    /// When `true`, requests without usable Fetch Metadata headers are rejected;
    /// when `false` they are allowed.
    reject_missing_metadata: bool,
    /// When `true`, `GET`, `HEAD` and `OPTIONS` requests are allowed before any
    /// Fetch Metadata header is inspected.
    allow_safe_methods: bool,
}
10
11impl Policy {
12 pub fn allow<B>(&self, request: &http::Request<B>) -> bool {
15 if self.allow_safe_methods
16 && method_in(
17 request.method(),
18 [Method::GET, Method::HEAD, Method::OPTIONS],
19 )
20 {
21 return true;
22 }
23
24 let sec_fetch_site = request.headers().get(header::SEC_FETCH_SITE);
25 let sec_fetch_mode = request.headers().get(header::SEC_FETCH_MODE);
26 let sec_fetch_dest = request.headers().get(header::SEC_FETCH_DEST);
27
28 let sec_fetch = zip3(sec_fetch_site, sec_fetch_mode, sec_fetch_dest);
29
30 let Some((sec_fetch_site, sec_fetch_mode, sec_fetch_dest)) = sec_fetch else {
31 return !self.reject_missing_metadata;
34 };
35
36 if header_in(sec_fetch_site, ["same-origin", "same-site", "none"]) {
37 return true;
39 }
40
41 if sec_fetch_mode == "navigate"
42 && request.method() == Method::GET
43 && !header_in(sec_fetch_dest, ["object", "embed"])
44 {
45 return true;
47 }
48
49 false
51 }
52}
53
/// Builder used to configure a [`Policy`].
///
/// Every option starts disabled.
#[derive(Debug)]
pub struct PolicyBuilder {
    /// Set by [`PolicyBuilder::reject_missing_metadata`].
    reject_missing_metadata: bool,
    /// Set by [`PolicyBuilder::allow_safe_methods`].
    allow_safe_methods: bool,
}
59
60impl PolicyBuilder {
61 pub(crate) fn new() -> Self {
62 Self {
63 reject_missing_metadata: false,
64 allow_safe_methods: false,
65 }
66 }
67
68 pub fn reject_missing_metadata(&mut self) -> &mut Self {
71 self.reject_missing_metadata = true;
72 self
73 }
74
75 pub fn allow_safe_methods(&mut self) -> &mut Self {
77 self.allow_safe_methods = true;
78 self
79 }
80
81 pub(crate) fn build(self) -> Policy {
82 Policy {
83 reject_missing_metadata: self.reject_missing_metadata,
84 allow_safe_methods: self.allow_safe_methods,
85 }
86 }
87}
88
/// Combines three options into one: `Some((a, b, c))` only when every input is
/// `Some`, otherwise `None`.
fn zip3<T1, T2, T3>(a: Option<T1>, b: Option<T2>, c: Option<T3>) -> Option<(T1, T2, T3)> {
    a.zip(b)
        .zip(c)
        .map(|((first, second), third)| (first, second, third))
}
95
96fn header_in(header: &HeaderValue, values: impl IntoIterator<Item = &'static str>) -> bool {
97 values
98 .into_iter()
99 .map(HeaderValue::from_static)
100 .any(|value| value == header)
101}
102
103fn method_in(method: &Method, values: impl IntoIterator<Item = Method>) -> bool {
104 values.into_iter().any(|value| value == method)
105}