tower_http/follow_redirect/policy/filter_credentials.rs

1use super::{eq_origin, Action, Attempt, Policy};
2use http::{
3    header::{self, HeaderName},
4    Request,
5};
6
/// A redirection [`Policy`] that strips credentials from redirected requests.
///
/// Every redirection is first classified as "blocked" or not (see `block_cross_origin`,
/// `block_any` and `block_none`); headers are then removed only from the requests of "blocked"
/// redirections (see `remove_blocklisted`, `remove_all` and `remove_none`).
#[derive(Clone, Debug)]
pub struct FilterCredentials {
    /// Mark redirections whose location is not same-origin with the previous request as blocked.
    block_cross_origin: bool,
    /// Mark every redirection as blocked, whatever its origin.
    block_any: bool,
    /// In blocked redirections, remove the credential headers on the blocklist.
    remove_blocklisted: bool,
    /// In blocked redirections, remove every header.
    remove_all: bool,
    /// Runtime state: written by `Policy::redirect`, read by `Policy::on_request` to decide
    /// whether the outgoing request must be filtered.
    blocked: bool,
}
16
17const BLOCKLIST: &[HeaderName] = &[
18    header::AUTHORIZATION,
19    header::COOKIE,
20    header::PROXY_AUTHORIZATION,
21];
22
23impl FilterCredentials {
24    /// Create a new [`FilterCredentials`] that removes blocklisted request headers in cross-origin
25    /// redirections.
26    pub fn new() -> Self {
27        FilterCredentials {
28            block_cross_origin: true,
29            block_any: false,
30            remove_blocklisted: true,
31            remove_all: false,
32            blocked: false,
33        }
34    }
35
36    /// Configure `self` to mark cross-origin redirections as "blocked".
37    pub fn block_cross_origin(mut self, enable: bool) -> Self {
38        self.block_cross_origin = enable;
39        self
40    }
41
42    /// Configure `self` to mark every redirection as "blocked".
43    pub fn block_any(mut self) -> Self {
44        self.block_any = true;
45        self
46    }
47
48    /// Configure `self` to mark no redirections as "blocked".
49    pub fn block_none(mut self) -> Self {
50        self.block_any = false;
51        self.block_cross_origin(false)
52    }
53
54    /// Configure `self` to remove blocklisted headers in "blocked" redirections.
55    ///
56    /// The blocklist includes the following headers:
57    ///
58    /// - `Authorization`
59    /// - `Cookie`
60    /// - `Proxy-Authorization`
61    pub fn remove_blocklisted(mut self, enable: bool) -> Self {
62        self.remove_blocklisted = enable;
63        self
64    }
65
66    /// Configure `self` to remove all headers in "blocked" redirections.
67    pub fn remove_all(mut self) -> Self {
68        self.remove_all = true;
69        self
70    }
71
72    /// Configure `self` to remove no headers in "blocked" redirections.
73    pub fn remove_none(mut self) -> Self {
74        self.remove_all = false;
75        self.remove_blocklisted(false)
76    }
77}
78
79impl Default for FilterCredentials {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl<B, E> Policy<B, E> for FilterCredentials {
86    fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
87        self.blocked = self.block_any
88            || (self.block_cross_origin && !eq_origin(attempt.previous(), attempt.location()));
89        Ok(Action::Follow)
90    }
91
92    fn on_request(&mut self, request: &mut Request<B>) {
93        if self.blocked {
94            let headers = request.headers_mut();
95            if self.remove_all {
96                headers.clear();
97            } else if self.remove_blocklisted {
98                for key in BLOCKLIST {
99                    headers.remove(key);
100                }
101            }
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use http::Uri;

    /// Builds a request to `uri` carrying a `Cookie` header and passes it through the policy's
    /// `on_request` hook, returning the (possibly filtered) request.
    fn send_with_cookie(policy: &mut FilterCredentials, uri: &Uri) -> Request<()> {
        let mut request = Request::builder()
            .uri(uri.clone())
            .header(header::COOKIE, "42")
            .body(())
            .unwrap();
        Policy::<(), ()>::on_request(policy, &mut request);
        request
    }

    /// Feeds a redirection from `previous` to `location` to the policy and checks it is followed.
    fn redirect_to(policy: &mut FilterCredentials, previous: &Uri, location: &Uri) {
        let attempt = Attempt {
            status: Default::default(),
            location,
            previous,
        };
        let action = Policy::<(), ()>::redirect(policy, &attempt).unwrap();
        assert!(action.is_follow());
    }

    #[test]
    fn works() {
        let mut policy = FilterCredentials::default();

        let initial = Uri::from_static("http://example.com/old");
        let same_origin = Uri::from_static("http://example.com/new");
        let cross_origin = Uri::from_static("https://example.com/new");

        // The initial request is never filtered.
        let first = send_with_cookie(&mut policy, &initial);
        assert!(first.headers().contains_key(header::COOKIE));

        // A same-origin redirection keeps credentials.
        redirect_to(&mut policy, first.uri(), &same_origin);
        let second = send_with_cookie(&mut policy, &same_origin);
        assert!(second.headers().contains_key(header::COOKIE));

        // A cross-origin redirection (http -> https) strips them.
        redirect_to(&mut policy, second.uri(), &cross_origin);
        let third = send_with_cookie(&mut policy, &cross_origin);
        assert!(!third.headers().contains_key(header::COOKIE));
    }
}