tower_http/follow_redirect/policy/
filter_credentials.rs1use super::{eq_origin, Action, Attempt, Policy};
2use http::{
3 header::{self, HeaderName},
4 Request,
5};
6
/// A redirection [`Policy`] that removes credentials from requests in redirections.
///
/// After every redirect the policy decides whether the follow-up request is
/// "blocked". Blocked requests have headers stripped according to the
/// `remove_*` settings; unblocked requests are passed through untouched.
#[derive(Clone, Debug)]
pub struct FilterCredentials {
    /// Block requests that follow a redirect to a different origin.
    block_cross_origin: bool,
    /// Block requests that follow any redirect, regardless of origin.
    block_any: bool,
    /// On blocked requests, remove the credential headers in `BLOCKLIST`.
    remove_blocklisted: bool,
    /// On blocked requests, remove every header (takes precedence over
    /// `remove_blocklisted`).
    remove_all: bool,
    /// Per-redirect state: whether the next request is blocked. Written by
    /// `redirect` and read by `on_request`.
    blocked: bool,
}
16
17const BLOCKLIST: &[HeaderName] = &[
18 header::AUTHORIZATION,
19 header::COOKIE,
20 header::PROXY_AUTHORIZATION,
21];
22
23impl FilterCredentials {
24 pub fn new() -> Self {
27 FilterCredentials {
28 block_cross_origin: true,
29 block_any: false,
30 remove_blocklisted: true,
31 remove_all: false,
32 blocked: false,
33 }
34 }
35
36 pub fn block_cross_origin(mut self, enable: bool) -> Self {
38 self.block_cross_origin = enable;
39 self
40 }
41
42 pub fn block_any(mut self) -> Self {
44 self.block_any = true;
45 self
46 }
47
48 pub fn block_none(mut self) -> Self {
50 self.block_any = false;
51 self.block_cross_origin(false)
52 }
53
54 pub fn remove_blocklisted(mut self, enable: bool) -> Self {
62 self.remove_blocklisted = enable;
63 self
64 }
65
66 pub fn remove_all(mut self) -> Self {
68 self.remove_all = true;
69 self
70 }
71
72 pub fn remove_none(mut self) -> Self {
74 self.remove_all = false;
75 self.remove_blocklisted(false)
76 }
77}
78
79impl Default for FilterCredentials {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl<B, E> Policy<B, E> for FilterCredentials {
86 fn redirect(&mut self, attempt: &Attempt<'_>) -> Result<Action, E> {
87 self.blocked = self.block_any
88 || (self.block_cross_origin && !eq_origin(attempt.previous(), attempt.location()));
89 Ok(Action::Follow)
90 }
91
92 fn on_request(&mut self, request: &mut Request<B>) {
93 if self.blocked {
94 let headers = request.headers_mut();
95 if self.remove_all {
96 headers.clear();
97 } else if self.remove_blocklisted {
98 for key in BLOCKLIST {
99 headers.remove(key);
100 }
101 }
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use http::Uri;

    /// Builds a request to `uri` carrying two credential headers and one
    /// non-credential header, then runs it through `policy.on_request`.
    fn send(policy: &mut FilterCredentials, uri: &Uri) -> Request<()> {
        let mut request = Request::builder()
            .uri(uri.clone())
            .header(header::COOKIE, "42")
            .header(header::AUTHORIZATION, "Basic Zm9vOmJhcg==")
            .header(header::ACCEPT, "*/*")
            .body(())
            .unwrap();
        Policy::<(), ()>::on_request(policy, &mut request);
        request
    }

    /// Feeds a redirect from `previous` to `location` into the policy and
    /// asserts that the policy chooses to follow it.
    fn follow(policy: &mut FilterCredentials, previous: &Uri, location: &Uri) {
        let attempt = Attempt {
            status: Default::default(),
            location,
            previous,
        };
        assert!(Policy::<(), ()>::redirect(policy, &attempt)
            .unwrap()
            .is_follow());
    }

    #[test]
    fn works() {
        let mut policy = FilterCredentials::default();

        let initial = Uri::from_static("http://example.com/old");
        let same_origin = Uri::from_static("http://example.com/new");
        let cross_origin = Uri::from_static("https://example.com/new");

        // No redirect has happened yet, so nothing is stripped.
        let request = send(&mut policy, &initial);
        assert!(request.headers().contains_key(header::COOKIE));

        follow(&mut policy, &initial, &same_origin);
        let request = send(&mut policy, &same_origin);
        assert!(request.headers().contains_key(header::COOKIE));

        // Only the scheme changes here, which is enough to make it cross-origin.
        follow(&mut policy, &same_origin, &cross_origin);
        let request = send(&mut policy, &cross_origin);
        assert!(!request.headers().contains_key(header::COOKIE));
        assert!(!request.headers().contains_key(header::AUTHORIZATION));
        assert!(request.headers().contains_key(header::ACCEPT));
    }

    #[test]
    fn block_any_strips_same_origin_redirects() {
        let mut policy = FilterCredentials::new().block_any();
        let previous = Uri::from_static("http://example.com/old");
        let target = Uri::from_static("http://example.com/new");

        follow(&mut policy, &previous, &target);
        let request = send(&mut policy, &target);
        assert!(!request.headers().contains_key(header::COOKIE));
        assert!(!request.headers().contains_key(header::AUTHORIZATION));
        assert!(request.headers().contains_key(header::ACCEPT));
    }

    #[test]
    fn block_none_keeps_credentials_across_origins() {
        let mut policy = FilterCredentials::new().block_none();
        let previous = Uri::from_static("http://example.com/old");
        let target = Uri::from_static("https://example.org/new");

        follow(&mut policy, &previous, &target);
        let request = send(&mut policy, &target);
        assert!(request.headers().contains_key(header::COOKIE));
        assert!(request.headers().contains_key(header::AUTHORIZATION));
    }

    #[test]
    fn remove_all_clears_every_header_when_blocked() {
        let mut policy = FilterCredentials::new().remove_all();
        let previous = Uri::from_static("http://example.com/old");
        let target = Uri::from_static("https://example.com/new");

        follow(&mut policy, &previous, &target);
        let request = send(&mut policy, &target);
        assert!(request.headers().is_empty());
    }

    #[test]
    fn remove_none_keeps_headers_even_when_blocked() {
        let mut policy = FilterCredentials::new().block_any().remove_none();
        let previous = Uri::from_static("http://example.com/old");
        let target = Uri::from_static("http://example.com/new");

        follow(&mut policy, &previous, &target);
        let request = send(&mut policy, &target);
        assert!(request.headers().contains_key(header::COOKIE));
        assert!(request.headers().contains_key(header::AUTHORIZATION));
    }
}